05. Transfer Learning with TensorFlow Part 2: Fine-tuning¶
Our previous section saw how transfer learning was able to learn so well on so little training data. This time, we'll be checking out the fine-tuning option.
We'll tweak the model by unfreezing a few layers at the top of the frozen base (typically 1 to 3 layers), allowing them to train and adapt to the problem we're working with, and then continue working our way down the model if needed.
What we're covering¶
We're going to go through the following with TensorFlow:
- Introduce fine-tuning, a type of transfer learning to modify a pre-trained model to be more suited to your data
- Using the Keras Functional API (a different way to build models in Keras)
- Using a smaller dataset to experiment faster (e.g. 1-10% of training samples of 10 classes of food)
- Data augmentation (how to make your dataset more diverse without adding more data)
- Running a series of modelling experiments on our Food Vision data
- Model 0: a transfer learning model using the Keras Functional API
- Model 1: a feature extraction transfer learning model on 1% of the data with data augmentation
- Model 2: a feature extraction transfer learning model on 10% of the data with data augmentation
- Model 3: a fine-tuned transfer learning model on 10% of the data
- Model 4: a fine-tuned transfer learning model on 100% of the data
- Introduce the ModelCheckpoint callback to save intermediate training results
- Compare model experiment results using TensorBoard
import datetime
print(f'Notebook last run (end-to-end): {datetime.datetime.now()}')
Notebook last run (end-to-end): 2025-09-14 13:10:33.836532
import tensorflow as tf
print(f'TensorFlow version: {tf.__version__}')
TensorFlow version: 2.20.0
!nvidia-smi
Sun Sep 14 13:19:44 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.94 Driver Version: 560.94 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Driver-Model | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce GTX 1060 6GB WDDM | 00000000:0A:00.0 On | N/A |
| 30% 54C P0 30W / 120W | 1518MiB / 6144MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 3428 C ...elchupacabra\App\Bandicam\bdcam.exe N/A |
| 0 N/A N/A 3876 C+G ...2txyewy\StartMenuExperienceHost.exe N/A |
| 0 N/A N/A 4868 C+G C:\Windows\explorer.exe N/A |
| 0 N/A N/A 5980 C+G ...5n1h2txyewy\ShellExperienceHost.exe N/A |
| 0 N/A N/A 6044 C+G ...n\139.0.3405.102\msedgewebview2.exe N/A |
| 0 N/A N/A 6368 C+G ....Search_cw5n1h2txyewy\SearchApp.exe N/A |
| 0 N/A N/A 7528 C+G ...CBS_cw5n1h2txyewy\TextInputHost.exe N/A |
| 0 N/A N/A 8980 C+G ...GeForce Experience\NVIDIA Share.exe N/A |
| 0 N/A N/A 9260 C+G ... Files\Elgato\WaveLink\WaveLink.exe N/A |
| 0 N/A N/A 11660 C+G ...oogle\Chrome\Application\chrome.exe N/A |
| 0 N/A N/A 16828 C+G ...1.0_x64__8wekyb3d8bbwe\Video.UI.exe N/A |
| 0 N/A N/A 17812 C+G ...remium\win64\bin\HarmonyPremium.exe N/A |
| 0 N/A N/A 17968 C+G ....Search_cw5n1h2txyewy\SearchApp.exe N/A |
| 0 N/A N/A 19188 C+G ...soft Office\root\Office16\EXCEL.EXE N/A |
| 0 N/A N/A 20236 C+G ...dobe\Adobe Animate 2021\Animate.exe N/A |
| 0 N/A N/A 20360 C+G ...t.LockApp_cw5n1h2txyewy\LockApp.exe N/A |
| 0 N/A N/A 21080 C+G ...\cef\cef.win7x64\steamwebhelper.exe N/A |
| 0 N/A N/A 24844 C+G ...ejd91yc\AdobeNotificationClient.exe N/A |
| 0 N/A N/A 25424 C+G X:\Mozilla_Thunderbird\thunderbird.exe N/A |
| 0 N/A N/A 25892 C+G ...oogle\Chrome\Application\chrome.exe N/A |
| 0 N/A N/A 26864 C+G ...GeForce Experience\NVIDIA Share.exe N/A |
| 0 N/A N/A 29604 C+G ...siveControlPanel\SystemSettings.exe N/A |
| 0 N/A N/A 36384 C+G ...3.0_x64__cv1g1gvanyjgm\WhatsApp.exe N/A |
| 0 N/A N/A 37408 C+G ...al\Discord\app-1.0.9205\Discord.exe N/A |
| 0 N/A N/A 39692 C+G ...cal\Microsoft\OneDrive\OneDrive.exe N/A |
| 0 N/A N/A 40840 C+G ...ekyb3d8bbwe\PhoneExperienceHost.exe N/A |
| 0 N/A N/A 41576 C+G X:\Microsoft VS Code\Code.exe N/A |
| 0 N/A N/A 44280 C+G ...cal\Microsoft\OneDrive\OneDrive.exe N/A |
+-----------------------------------------------------------------------------------------+
# Install it first if you haven't
!pip install scikit-learn
Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp310-cp310-win_amd64.whl.metadata (11 kB)
Requirement already satisfied: numpy>=1.22.0 in x:\miniconda3\envs\tfenv\lib\site-packages (from scikit-learn) (2.1.3)
Collecting scipy>=1.8.0 (from scikit-learn)
Collecting joblib>=1.2.0 (from scikit-learn)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
Successfully installed joblib-1.5.2 scikit-learn-1.7.2 scipy-1.15.3 threadpoolctl-3.6.0
# get the helper_functions.py script from the course GitHub
!curl -O https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
# import helper functions we're going to use
import sys, os
sys.path.append(os.getcwd())
import sklearn
from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, walk_through_dir
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0
100 10246 100 10246 0 0 34740 0 --:--:-- --:--:-- --:--:-- 35089
Now that we've got a bunch of helper functions, we can call them rather than rewriting everything from scratch.
10 Food Classes: Working with Less Data¶
We saw how 10% of the training data can still train an effective model. In this notebook, we'll continue with the same method but on even smaller datasets, using pretrained models from the tf.keras.applications module and also fine-tuning them.
We'll also practise using a dataloader function called image_dataset_from_directory(), which is part of the tf.keras.utils module.
Finally, we'll play with the Keras Functional API for building deep learning models. It's a more flexible way to create models than the tf.keras.Sequential API.
We'll explore each of these in more detail as we go, but let's start by downloading some data.
Let's check out the 10% data of our 10 food classes by inspecting its directories.
# walk through 10 percent data directory of 10 classes of food
walk_through_dir('10_food_classes_10_percent')
There are 2 directories and 0 images in '10_food_classes_10_percent'. There are 10 directories and 0 images in '10_food_classes_10_percent\test'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\ice_cream'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\chicken_curry'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\steak'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\sushi'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\chicken_wings'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\grilled_salmon'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\hamburger'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\pizza'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\ramen'. There are 0 directories and 250 images in '10_food_classes_10_percent\test\fried_rice'. There are 10 directories and 0 images in '10_food_classes_10_percent\train'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\ice_cream'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\chicken_curry'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\steak'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\sushi'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\chicken_wings'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\grilled_salmon'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\hamburger'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\pizza'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\ramen'. There are 0 directories and 75 images in '10_food_classes_10_percent\train\fried_rice'.
We can see the 10% training data is represented by 75 images per class, while the test set remains the same at 250 images per class.
Let's define our training and test filepaths.
# create training and test directories
train_dir = '10_food_classes_10_percent/train/'
test_dir = '10_food_classes_10_percent/test/'
With our image data, we need a way to load it into a TensorFlow-compatible format. We previously used ImageDataGenerator for this, but it's now legacy, making it less than ideal for future model builds.
We will move onto tf.keras.utils.image_dataset_from_directory(), which expects the following directory format:
10_food_classes_10_percent <- top level folder
└───train <- training images
│ └───pizza
│ │ │ 1008104.jpg
│ │ │ 1638227.jpg
│ │ │ ...
│ └───steak
│ │ 1000205.jpg
│ │ 1647351.jpg
│ │ ...
│
└───test <- testing images
│ └───pizza
│ │ │ 1001116.jpg
│ │ │ 1507019.jpg
│ │ │ ...
│ └───steak
│ │ 100274.jpg
│ │ 1653815.jpg
│ │ ...
The benefit of using tf.keras.utils.image_dataset_from_directory() rather than ImageDataGenerator is that it creates a tf.data.Dataset object rather than a generator.
So what is tf.data.Dataset? It's an API for building data input pipelines that is much more efficient and faster than the pipeline produced by ImageDataGenerator.
An API is a set of tools, functions, and rules that allow you to interact with a library or piece of software. Basically, it's a way to 'talk to' software and use its functionality without needing to know everything about its internals.
It's this design that makes tf.data.Dataset so much faster: it can load and preprocess data in parallel, whereas ImageDataGenerator loads images one at a time from a Python generator.
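As a minimal sketch of what a tf.data input pipeline looks like (using synthetic NumPy arrays rather than our food images, so the shapes and names here are illustrative only), a typical pattern is to shuffle, batch, and then prefetch so the next batch is prepared while the current one is being used:

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for 64 images and their one-hot labels (10 classes).
images = np.random.rand(64, 224, 224, 3).astype("float32")
labels = np.eye(10, dtype="float32")[np.random.randint(0, 10, size=64)]

dataset = (tf.data.Dataset.from_tensor_slices((images, labels))
           .shuffle(buffer_size=64)       # shuffle samples each epoch
           .batch(32)                     # group samples into batches of 32
           .prefetch(tf.data.AUTOTUNE))   # prepare the next batch in parallel

for image_batch, label_batch in dataset.take(1):
    print(image_batch.shape, label_batch.shape)  # (32, 224, 224, 3) (32, 10)
```

image_dataset_from_directory() builds a pipeline like this for us behind the scenes, which is why its output prints as a prefetched dataset.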
# create data inputs
import tensorflow as tf
IMG_SIZE = (224,224) # define size
train_data_10_percent = tf.keras.utils.image_dataset_from_directory(directory=train_dir,
image_size=IMG_SIZE,
label_mode='categorical',
batch_size=32)
test_data_10_percent = tf.keras.utils.image_dataset_from_directory(directory=test_dir,
image_size=IMG_SIZE,
label_mode='categorical')
Found 750 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
Looks like our dataloader has found the correct number of images in each dataset :)
Now, the main parameters to be concerned with are:
- directory - the filepath of the target directory we're loading images from
- image_size - the target size of the images (height, width)
- batch_size - the number of images to load per batch (the default of 32 works well for general problems)
We should see the datatype being a batched (and prefetched) Dataset with shapes relating to our data.
# check the training datatype
train_data_10_percent
<_PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>
In the above output:
- (None, 224, 224, 3) refers to the tensor shape of our images: None is the batch size, 224, 224 are the height and width, and 3 is the number of colour channels (red, green, blue).
- (None, 10) refers to the tensor shape of our labels: None is the batch size and 10 is the number of classes in our problem.
- Both images and labels are of datatype tf.float32.
The batch size is None because it's only set during model training. Think of it as a placeholder that will be filled with the batch_size parameter, in our case 32.
Another benefit of tf.data.Dataset is the methods and attributes that come with it, such as finding the names of the classes via the class_names attribute.
# check out the class names of our dataset
train_data_10_percent.class_names
['chicken_curry', 'chicken_wings', 'fried_rice', 'grilled_salmon', 'hamburger', 'ice_cream', 'pizza', 'ramen', 'steak', 'sushi']
Or if we wanted to see an example batch of data, we could use the take() method.
# see an example batch of data
for images, labels in train_data_10_percent.take(1):
print(images, labels)
tf.Tensor( [[[[5.88265305e+01 3.65714302e+01 2.04081631e+00] [5.46683655e+01 3.16428566e+01 2.11734700e+00] [5.10051003e+01 3.00663261e+01 3.57142878e+00] ... [4.57140684e+00 2.57140684e+00 5.99993467e+00] [3.78572750e+00 1.81124234e+00 2.73469758e+00] [6.19899225e+00 5.19899225e+00 3.19899225e+00]] [[7.34540787e+01 4.85714302e+01 1.50000000e+01] [7.05714264e+01 4.52857132e+01 1.44897957e+01] [6.80153046e+01 4.38724480e+01 1.82602043e+01] ... [2.19895935e+00 1.98959351e-01 3.59687805e+00] [4.22451782e+00 2.36228490e+00 2.93877745e+00] [1.09286413e+01 9.92864132e+00 7.12245369e+00]] [[8.49387741e+01 5.57908173e+01 2.55051003e+01] [8.39438705e+01 5.45000000e+01 2.62704067e+01] [8.30459137e+01 5.37346916e+01 2.85459175e+01] ... [4.35714245e+00 2.35714245e+00 5.35714245e+00] [9.21433735e+00 8.21433735e+00 6.07145405e+00] [1.79235649e+01 1.69235649e+01 1.26429262e+01]] ... [[9.21935558e+00 9.21935558e+00 1.12193556e+01] [8.72953987e+00 8.72953987e+00 1.07295399e+01] [7.02037716e+00 7.02037716e+00 9.02037716e+00] ... [1.57857056e+01 1.17857056e+01 1.12142334e+01] [1.40000000e+01 1.30000000e+01 9.00000000e+00] [1.50765409e+01 1.40765409e+01 9.07654095e+00]] [[5.42855406e+00 5.42855406e+00 7.42855406e+00] [2.21425056e+00 2.21425056e+00 4.21425056e+00] [3.28569698e+00 3.28569698e+00 5.28569698e+00] ... [1.50000000e+01 1.10000000e+01 1.04285278e+01] [1.40000000e+01 1.30000000e+01 9.00000000e+00] [1.50255060e+01 1.40255060e+01 9.02550602e+00]] [[4.77040958e+00 4.77040958e+00 4.77040958e+00] [5.23978758e+00 5.23978758e+00 7.23978758e+00] [3.36734438e+00 3.36734438e+00 5.36734438e+00] ... [1.64285889e+01 1.24285889e+01 1.18571167e+01] [1.48826323e+01 1.38826323e+01 9.88263226e+00] [1.18979216e+01 1.08979216e+01 5.89792156e+00]]] [[[3.73571434e+01 3.63571434e+01 5.23571434e+01] [3.73571434e+01 3.63571434e+01 5.23571434e+01] [3.68622437e+01 3.58622437e+01 5.18622437e+01] ... 
[2.04933334e+02 2.18862015e+02 2.01443420e+02] [1.62163055e+02 1.90479446e+02 1.41137421e+02] [1.27586342e+02 1.65683350e+02 9.13973923e+01]] [[3.68571434e+01 3.78571434e+01 5.58571434e+01] [3.59285736e+01 3.69285736e+01 5.49285736e+01] [3.78010216e+01 3.98010216e+01 5.48010216e+01] ... [2.22958923e+02 2.33688644e+02 2.24703812e+02] [1.89714081e+02 2.11280457e+02 1.78984421e+02] [1.53545502e+02 1.82193497e+02 1.30596390e+02]] [[3.61377563e+01 3.91377563e+01 5.61377563e+01] [3.74132652e+01 4.04132652e+01 5.74132652e+01] [3.71224480e+01 4.01224480e+01 5.71224480e+01] ... [2.39556015e+02 2.43642838e+02 2.45387665e+02] [2.25489700e+02 2.36362183e+02 2.28178452e+02] [2.05867050e+02 2.21938507e+02 2.04300659e+02]] ... [[9.80656357e+01 1.03065636e+02 1.45065643e+02] [1.37637436e+02 1.42208908e+02 1.84423172e+02] [9.08619919e+01 9.31732483e+01 1.35867172e+02] ... [1.21566559e+02 1.07071625e+02 1.11474739e+02] [1.28199265e+02 1.13199265e+02 1.17770737e+02] [1.39362610e+02 1.24362602e+02 1.28934082e+02]] [[5.87757721e+01 6.38472137e+01 1.01653297e+02] [6.94793167e+01 7.26170883e+01 1.13198677e+02] [9.11937180e+01 9.41937180e+01 1.37622284e+02] ... [1.34188995e+02 1.21188988e+02 1.28188995e+02] [1.50576630e+02 1.35505188e+02 1.42719513e+02] [1.54331696e+02 1.38663345e+02 1.47668365e+02]] [[1.41307678e+02 1.46307678e+02 1.77593323e+02] [5.71127129e+01 6.11127129e+01 9.79698639e+01] [4.59999657e+01 4.89999657e+01 9.18672562e+01] ... [1.47806183e+02 1.34806183e+02 1.43806183e+02] [1.54454163e+02 1.41454163e+02 1.51454163e+02] [1.65785889e+02 1.49785889e+02 1.60785889e+02]]] [[[1.63887756e+02 1.65887756e+02 1.44887756e+02] [1.80382660e+02 1.79382660e+02 1.59382660e+02] [1.59581635e+02 1.62795914e+02 1.45663269e+02] ... 
[1.38356964e+02 1.42928436e+02 1.37275421e+02] [1.23331627e+02 1.29545959e+02 1.27474510e+02] [1.20596802e+02 1.30183563e+02 1.29183563e+02]] [[1.69051025e+02 1.68051025e+02 1.50051025e+02] [1.82280624e+02 1.81280624e+02 1.61280624e+02] [1.65816330e+02 1.69030609e+02 1.52658173e+02] ... [1.39300888e+02 1.41285599e+02 1.35999908e+02] [1.17556183e+02 1.21847031e+02 1.20704155e+02] [1.18361900e+02 1.25571182e+02 1.25168091e+02]] [[1.63005112e+02 1.61576538e+02 1.43790817e+02] [1.78515305e+02 1.77484695e+02 1.59500000e+02] [1.58908157e+02 1.61122452e+02 1.47765305e+02] ... [1.40887711e+02 1.41459183e+02 1.36459183e+02] [1.27321396e+02 1.29479584e+02 1.26576546e+02] [1.23515221e+02 1.28076508e+02 1.28367294e+02]] ... [[8.90663528e+01 4.30663567e+01 1.90663567e+01] [8.72703857e+01 4.12703857e+01 1.72703876e+01] [8.71683273e+01 4.11683273e+01 1.71683273e+01] ... [1.65091812e+02 1.03091820e+02 2.60918179e+01] [1.61775528e+02 9.97755203e+01 2.27755241e+01] [1.66260300e+02 1.04260292e+02 2.72602921e+01]] [[8.82142563e+01 4.22142601e+01 1.82142601e+01] [8.69948959e+01 4.09948959e+01 1.69948978e+01] [8.90000000e+01 4.30000000e+01 1.90000000e+01] ... [1.62489761e+02 1.00489754e+02 2.34897575e+01] [1.66525421e+02 1.04525429e+02 2.75254288e+01] [1.64357300e+02 1.02357300e+02 2.53572998e+01]] [[8.67143555e+01 4.07143555e+01 1.67143555e+01] [8.86429291e+01 4.26429291e+01 1.86429272e+01] [9.01428909e+01 4.41428909e+01 2.01428909e+01] ... [1.65112228e+02 1.03112228e+02 2.61122284e+01] [1.59831635e+02 9.78316269e+01 2.08316307e+01] [1.63831390e+02 1.01831383e+02 2.48313828e+01]]] ... [[[2.22500000e+02 2.20500000e+02 2.07500000e+02] [2.22591843e+02 2.20591843e+02 2.07591843e+02] [2.25346939e+02 2.23346939e+02 2.10346939e+02] ... 
[2.23300995e+02 2.22872467e+02 2.10872467e+02] [2.20454071e+02 2.20454071e+02 2.08454071e+02] [2.21153107e+02 2.21153107e+02 2.09153107e+02]] [[2.21714279e+02 2.19714279e+02 2.06714279e+02] [2.20142853e+02 2.18142853e+02 2.05142853e+02] [2.18913269e+02 2.16913269e+02 2.04127548e+02] ... [2.20224472e+02 2.18224472e+02 2.06224472e+02] [2.19423462e+02 2.19423462e+02 2.07423462e+02] [2.19994888e+02 2.19994888e+02 2.07994888e+02]] [[2.19581635e+02 2.20581635e+02 2.06153061e+02] [2.19443878e+02 2.20443878e+02 2.06443878e+02] [2.22188782e+02 2.19928574e+02 2.08571442e+02] ... [2.26096985e+02 2.23668457e+02 2.11882721e+02] [2.22499969e+02 2.22071396e+02 2.10071396e+02] [2.21627640e+02 2.21627640e+02 2.09627640e+02]] ... [[2.20653030e+02 2.21653030e+02 2.07653030e+02] [2.20301102e+02 2.21301102e+02 2.07301102e+02] [2.15642853e+02 2.16642853e+02 2.02642853e+02] ... [2.18428528e+02 2.18428528e+02 2.06428528e+02] [2.18627487e+02 2.18627487e+02 2.06627487e+02] [2.21984863e+02 2.21984863e+02 2.09984863e+02]] [[2.18887802e+02 2.19887802e+02 2.05887802e+02] [2.22443802e+02 2.23443802e+02 2.09443802e+02] [2.22530579e+02 2.23530579e+02 2.09530579e+02] ... [2.20698990e+02 2.20698990e+02 2.08698990e+02] [2.19923462e+02 2.19923462e+02 2.07923462e+02] [2.19903137e+02 2.19903137e+02 2.07903137e+02]] [[2.20004974e+02 2.21004974e+02 2.07004974e+02] [2.20234894e+02 2.21234894e+02 2.07234894e+02] [2.17494919e+02 2.18494919e+02 2.04494919e+02] ... [2.17505081e+02 2.17505081e+02 2.05505081e+02] [2.22596985e+02 2.22596985e+02 2.10596985e+02] [2.20071533e+02 2.20071533e+02 2.08071533e+02]]] [[[1.35867348e+01 8.58673477e+00 2.58673477e+00] [1.10714283e+01 8.07142830e+00 1.07142830e+00] [1.07857141e+01 8.21428585e+00 3.21428585e+00] ... 
[1.24285278e+01 9.00000000e+00 2.21426392e+00] [1.20714417e+01 9.07144165e+00 2.07144165e+00] [1.20000000e+01 1.10000000e+01 6.00000000e+00]] [[1.30255098e+01 8.02550983e+00 2.02551007e+00] [1.19336739e+01 8.93367386e+00 1.93367362e+00] [9.01530647e+00 8.01530647e+00 3.01530600e+00] ... [1.30000000e+01 1.00000000e+01 5.00000000e+00] [1.30663385e+01 1.00663385e+01 3.06633878e+00] [1.20000000e+01 1.20000000e+01 4.00000000e+00]] [[1.30000000e+01 8.00000000e+00 2.00000000e+00] [1.21989794e+01 9.19897938e+00 2.19897985e+00] [9.57142830e+00 9.00000000e+00 3.78571415e+00] ... [1.33571644e+01 1.07857361e+01 5.83164978e+00] [1.50000000e+01 1.20000000e+01 7.00000000e+00] [1.40000000e+01 1.10000000e+01 4.00000000e+00]] ... [[2.42859316e+00 8.85712147e+00 0.00000000e+00] [5.05612373e+00 8.10203743e+00 1.53045058e-02] [4.28567123e+00 6.81120777e+00 0.00000000e+00] ... [1.45714722e+01 1.00000000e+01 5.78573608e+00] [1.50000000e+01 1.00000000e+01 4.00000000e+00] [1.46428223e+01 9.64282227e+00 3.64282227e+00]] [[6.00000000e+00 9.00000000e+00 0.00000000e+00] [6.06632519e+00 9.06632519e+00 6.63253441e-02] [6.27041864e+00 9.27041817e+00 2.70418555e-01] ... [1.57296019e+01 1.11581297e+01 6.94386578e+00] [1.40663376e+01 9.06633759e+00 3.06633782e+00] [1.50255175e+01 1.00255175e+01 4.02551746e+00]] [[5.28564453e+00 6.28564453e+00 6.42822266e-01] [6.57139397e+00 7.57139397e+00 1.57139397e+00] [7.00000000e+00 1.00000000e+01 3.00000000e+00] ... [1.35050888e+01 8.93361664e+00 4.71935272e+00] [1.30714417e+01 8.07144165e+00 2.07144165e+00] [1.53571777e+01 1.03571777e+01 4.35717773e+00]]] [[[2.07000000e+02 1.98000000e+02 1.83000000e+02] [2.06642853e+02 1.97642853e+02 1.82642853e+02] [2.08076523e+02 1.98862244e+02 1.84505096e+02] ... 
...(pixel values truncated)...], shape=(32, 224, 224, 3), dtype=float32) tf.Tensor( [[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] ...(one-hot label rows truncated)...], shape=(32, 10), dtype=float32)
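The second tensor holds one-hot encoded labels: each row contains a single 1 marking the sample's class index. A quick NumPy sketch (using hypothetical rows in the same style as the batch above, not the actual batch) shows how `argmax` recovers the integer class indices:

```python
import numpy as np

# Two hypothetical one-hot label rows (10 classes), like the batch above
labels = np.array([
    [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],  # class 1
    [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],  # class 7
])

# argmax along the class axis recovers the integer class indices
class_indices = labels.argmax(axis=1)
print(class_indices)  # -> [1 7]
```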
Model 0: Building a transfer learning model using the Keras Functional API¶
To build our model, we'll use the `tf.keras.applications` module, which contains a series of models pretrained on ImageNet, along with the Keras Functional API.
The steps we'll take are:
1. Instantiate a pretrained base model object by choosing a target model such as `EfficientNetV2B0` from `tf.keras.applications.efficientnet_v2`, setting the `include_top` parameter to `False` (we'll create our own top layers, since our output shape is different from the base model's).
2. Set the base model's `trainable` attribute to `False` to freeze all of the weights in the pretrained model.
3. Define an input layer for our model (what shape of data should our model expect?).
4. [Optional] Normalize the inputs to our model if required. Some models, such as `ResNet50V2`, require their inputs to be between 0 and 1.
Note:
The `EfficientNet` and `EfficientNetV2` models in the `tf.keras.applications` module do not require images to be normalized on input, but many other models do.
5. Pass the inputs to the base model.
6. Pool the outputs of the base model into a shape compatible with the output activation layer (the base model outputs a 3D tensor per image, but the output layer needs a 1D feature vector). This can be done with average pooling or max pooling.
7. Create an output activation layer using `tf.keras.layers.Dense()` with the appropriate activation function and number of neurons.
8. Combine the input and output layers into a model using `tf.keras.Model()`.
9. Compile the model with an appropriate loss function and optimizer.
10. Fit the model for the desired number of epochs, with the necessary callbacks.
Sounds like a lot, so let's see it in practice.
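Step 6 is worth a closer look before we use it. Global average pooling collapses each feature map to a single number by averaging over the spatial dimensions, turning a 4D batch of feature maps into a 2D batch of feature vectors. A minimal NumPy sketch of the same operation (on a toy tensor, not real model activations):

```python
import numpy as np

# Toy batch of feature maps: (batch, height, width, channels)
feature_maps = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)

# GlobalAveragePooling2D averages over the spatial axes (height and width),
# leaving one value per channel, per sample
pooled = feature_maps.mean(axis=(1, 2))

print(feature_maps.shape)  # (2, 4, 4, 3)
print(pooled.shape)        # (2, 3)
```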
# 1. Create the base model with tf.keras.applications
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
# 2. Freeze the base model (so the pre-learned patterns remain)
base_model.trainable = False
# 3. Create inputs into the base model
inputs = tf.keras.layers.Input(shape=(224, 224, 3), name='input_layer')
# 4. [Optional] If using a model which expects normalized inputs (such as ResNet50V2), rescale here; not needed for EfficientNetV2
# x = tf.keras.layers.Rescaling(1./255)(inputs)
# 5. Pass the inputs to the base_model (EfficientNetV2 models from tf.keras.applications don't need normalized inputs)
x = base_model(inputs)
# 6. Average pool the outputs of the base model (aggregate the most important information, reduce the number of computations)
x = tf.keras.layers.GlobalAveragePooling2D(name='global_average_pooling_layer')(x)
# 7. Create the output activation layer
outputs = tf.keras.layers.Dense(10, activation='softmax', name='output_layer')(x)
# 8. Combine the inputs and outputs into a model
model_0 = tf.keras.Model(inputs, outputs)
# 9. Compile the model
model_0.compile(loss='categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
# 10. Fit the model (we use fewer validation steps so it's faster)
history_10_percent = model_0.fit(train_data_10_percent,
                                 epochs=5,
                                 steps_per_epoch=len(train_data_10_percent),
                                 validation_data=test_data_10_percent,
                                 # go through less of the validation data so epochs are faster
                                 validation_steps=int(0.25 * len(test_data_10_percent)),
                                 # track our model's training logs for visualization later
                                 callbacks=[create_tensorboard_callback('transfer_learning', '10_percent_feature_extract')])
Saving TensorBoard log files to: transfer_learning/10_percent_feature_extract/20251006-204716 Epoch 1/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 26s 698ms/step - accuracy: 0.4120 - loss: 1.8944 - val_accuracy: 0.7303 - val_loss: 1.2917 Epoch 2/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 13s 564ms/step - accuracy: 0.7400 - loss: 1.1468 - val_accuracy: 0.8224 - val_loss: 0.8740 Epoch 3/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 13s 560ms/step - accuracy: 0.8240 - loss: 0.8384 - val_accuracy: 0.8306 - val_loss: 0.7084 Epoch 4/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 13s 570ms/step - accuracy: 0.8520 - loss: 0.6846 - val_accuracy: 0.8618 - val_loss: 0.6001 Epoch 5/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 14s 584ms/step - accuracy: 0.8707 - loss: 0.5868 - val_accuracy: 0.8635 - val_loss: 0.5423
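As a sanity check on the "24/24" in the logs above, the number of batches per epoch comes from rounding up the image count over the batch size. A quick sketch, assuming the 10% training split holds 750 images (75 per class across 10 classes) and the full test set holds 2500 images (250 per class), batched in groups of 32:

```python
import math

batch_size = 32
train_images = 750   # assumed: 10% split of 10 food classes, 75 images each
test_images = 2500   # assumed: full test set, 250 images per class

# len() of a batched tf.data.Dataset is the number of batches (rounded up)
steps_per_epoch = math.ceil(train_images / batch_size)
validation_steps = int(0.25 * math.ceil(test_images / batch_size))

print(steps_per_epoch)   # -> 24, matching the "24/24" in the training logs
print(validation_steps)  # -> 19
```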
The model has done quite well after only 5 epochs.
Keep in mind, we used feature extraction transfer learning, similar to what we did in notebook 04.
We also built the model with the Keras Functional API rather than the Sequential API. The benefits of either may not seem clear right now, but when you start building more sophisticated models, you'll probably want to use the Functional API. Either way, it's important to have exposure to different ways of building models.
To see the benefits and use cases of the Functional API vs. the Sequential API, check out the TensorFlow Functional API documentation.
Let's check out the layers in our model, starting with the base.
# check layers in our base model
for layer_number, layer in enumerate(base_model.layers):
print(layer_number, layer.name)
0 input_layer_1 1 rescaling_1 2 normalization_1 3 stem_conv 4 stem_bn 5 stem_activation 6 block1a_project_conv 7 block1a_project_bn 8 block1a_project_activation 9 block2a_expand_conv 10 block2a_expand_bn 11 block2a_expand_activation 12 block2a_project_conv 13 block2a_project_bn 14 block2b_expand_conv 15 block2b_expand_bn 16 block2b_expand_activation 17 block2b_project_conv 18 block2b_project_bn 19 block2b_drop 20 block2b_add 21 block3a_expand_conv 22 block3a_expand_bn 23 block3a_expand_activation 24 block3a_project_conv 25 block3a_project_bn 26 block3b_expand_conv 27 block3b_expand_bn 28 block3b_expand_activation 29 block3b_project_conv 30 block3b_project_bn 31 block3b_drop 32 block3b_add 33 block4a_expand_conv 34 block4a_expand_bn 35 block4a_expand_activation 36 block4a_dwconv2 37 block4a_bn 38 block4a_activation 39 block4a_se_squeeze 40 block4a_se_reshape 41 block4a_se_reduce 42 block4a_se_expand 43 block4a_se_excite 44 block4a_project_conv 45 block4a_project_bn 46 block4b_expand_conv 47 block4b_expand_bn 48 block4b_expand_activation 49 block4b_dwconv2 50 block4b_bn 51 block4b_activation 52 block4b_se_squeeze 53 block4b_se_reshape 54 block4b_se_reduce 55 block4b_se_expand 56 block4b_se_excite 57 block4b_project_conv 58 block4b_project_bn 59 block4b_drop 60 block4b_add 61 block4c_expand_conv 62 block4c_expand_bn 63 block4c_expand_activation 64 block4c_dwconv2 65 block4c_bn 66 block4c_activation 67 block4c_se_squeeze 68 block4c_se_reshape 69 block4c_se_reduce 70 block4c_se_expand 71 block4c_se_excite 72 block4c_project_conv 73 block4c_project_bn 74 block4c_drop 75 block4c_add 76 block5a_expand_conv 77 block5a_expand_bn 78 block5a_expand_activation 79 block5a_dwconv2 80 block5a_bn 81 block5a_activation 82 block5a_se_squeeze 83 block5a_se_reshape 84 block5a_se_reduce 85 block5a_se_expand 86 block5a_se_excite 87 block5a_project_conv 88 block5a_project_bn 89 block5b_expand_conv 90 block5b_expand_bn 91 block5b_expand_activation 92 block5b_dwconv2 93 block5b_bn 94 
block5b_activation 95 block5b_se_squeeze 96 block5b_se_reshape 97 block5b_se_reduce 98 block5b_se_expand 99 block5b_se_excite 100 block5b_project_conv 101 block5b_project_bn 102 block5b_drop 103 block5b_add 104 block5c_expand_conv 105 block5c_expand_bn 106 block5c_expand_activation 107 block5c_dwconv2 108 block5c_bn 109 block5c_activation 110 block5c_se_squeeze 111 block5c_se_reshape 112 block5c_se_reduce 113 block5c_se_expand 114 block5c_se_excite 115 block5c_project_conv 116 block5c_project_bn 117 block5c_drop 118 block5c_add 119 block5d_expand_conv 120 block5d_expand_bn 121 block5d_expand_activation 122 block5d_dwconv2 123 block5d_bn 124 block5d_activation 125 block5d_se_squeeze 126 block5d_se_reshape 127 block5d_se_reduce 128 block5d_se_expand 129 block5d_se_excite 130 block5d_project_conv 131 block5d_project_bn 132 block5d_drop 133 block5d_add 134 block5e_expand_conv 135 block5e_expand_bn 136 block5e_expand_activation 137 block5e_dwconv2 138 block5e_bn 139 block5e_activation 140 block5e_se_squeeze 141 block5e_se_reshape 142 block5e_se_reduce 143 block5e_se_expand 144 block5e_se_excite 145 block5e_project_conv 146 block5e_project_bn 147 block5e_drop 148 block5e_add 149 block6a_expand_conv 150 block6a_expand_bn 151 block6a_expand_activation 152 block6a_dwconv2 153 block6a_bn 154 block6a_activation 155 block6a_se_squeeze 156 block6a_se_reshape 157 block6a_se_reduce 158 block6a_se_expand 159 block6a_se_excite 160 block6a_project_conv 161 block6a_project_bn 162 block6b_expand_conv 163 block6b_expand_bn 164 block6b_expand_activation 165 block6b_dwconv2 166 block6b_bn 167 block6b_activation 168 block6b_se_squeeze 169 block6b_se_reshape 170 block6b_se_reduce 171 block6b_se_expand 172 block6b_se_excite 173 block6b_project_conv 174 block6b_project_bn 175 block6b_drop 176 block6b_add 177 block6c_expand_conv 178 block6c_expand_bn 179 block6c_expand_activation 180 block6c_dwconv2 181 block6c_bn 182 block6c_activation 183 block6c_se_squeeze 184 block6c_se_reshape 185 
block6c_se_reduce 186 block6c_se_expand 187 block6c_se_excite 188 block6c_project_conv 189 block6c_project_bn 190 block6c_drop 191 block6c_add 192 block6d_expand_conv 193 block6d_expand_bn 194 block6d_expand_activation 195 block6d_dwconv2 196 block6d_bn 197 block6d_activation 198 block6d_se_squeeze 199 block6d_se_reshape 200 block6d_se_reduce 201 block6d_se_expand 202 block6d_se_excite 203 block6d_project_conv 204 block6d_project_bn 205 block6d_drop 206 block6d_add 207 block6e_expand_conv 208 block6e_expand_bn 209 block6e_expand_activation 210 block6e_dwconv2 211 block6e_bn 212 block6e_activation 213 block6e_se_squeeze 214 block6e_se_reshape 215 block6e_se_reduce 216 block6e_se_expand 217 block6e_se_excite 218 block6e_project_conv 219 block6e_project_bn 220 block6e_drop 221 block6e_add 222 block6f_expand_conv 223 block6f_expand_bn 224 block6f_expand_activation 225 block6f_dwconv2 226 block6f_bn 227 block6f_activation 228 block6f_se_squeeze 229 block6f_se_reshape 230 block6f_se_reduce 231 block6f_se_expand 232 block6f_se_excite 233 block6f_project_conv 234 block6f_project_bn 235 block6f_drop 236 block6f_add 237 block6g_expand_conv 238 block6g_expand_bn 239 block6g_expand_activation 240 block6g_dwconv2 241 block6g_bn 242 block6g_activation 243 block6g_se_squeeze 244 block6g_se_reshape 245 block6g_se_reduce 246 block6g_se_expand 247 block6g_se_excite 248 block6g_project_conv 249 block6g_project_bn 250 block6g_drop 251 block6g_add 252 block6h_expand_conv 253 block6h_expand_bn 254 block6h_expand_activation 255 block6h_dwconv2 256 block6h_bn 257 block6h_activation 258 block6h_se_squeeze 259 block6h_se_reshape 260 block6h_se_reduce 261 block6h_se_expand 262 block6h_se_excite 263 block6h_project_conv 264 block6h_project_bn 265 block6h_drop 266 block6h_add 267 top_conv 268 top_bn 269 top_activation
That's a lot of layers! Coding them all by hand would take a long time, yet thanks to transfer learning we can leverage all of them with just a few lines of code.
What about a summary of the base model?
base_model.summary()
Model: "efficientnetv2-b0"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input_layer_1 │ (None, None, │ 0 │ - │ │ (InputLayer) │ None, 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ rescaling_1 │ (None, None, │ 0 │ input_layer_1[0]… │ │ (Rescaling) │ None, 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ normalization_1 │ (None, None, │ 0 │ rescaling_1[0][0] │ │ (Normalization) │ None, 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_conv (Conv2D) │ (None, None, │ 864 │ normalization_1[… │ │ │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_bn │ (None, None, │ 128 │ stem_conv[0][0] │ │ (BatchNormalizatio… │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ stem_activation │ (None, None, │ 0 │ stem_bn[0][0] │ │ (Activation) │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_project_co… │ (None, None, │ 4,608 │ stem_activation[… │ │ (Conv2D) │ None, 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_project_bn │ (None, None, │ 64 │ block1a_project_… │ │ (BatchNormalizatio… │ None, 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block1a_project_ac… │ (None, None, │ 0 │ block1a_project_… │ │ (Activation) │ None, 16) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_conv │ (None, None, │ 9,216 │ block1a_project_… │ │ (Conv2D) │ None, 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_bn │ (None, None, │ 256 │ block2a_expand_c… │ │ (BatchNormalizatio… │ None, 64) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_expand_act… │ (None, None, │ 0 │ block2a_expand_b… │ │ (Activation) │ None, 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_project_co… │ (None, None, │ 2,048 │ block2a_expand_a… │ │ (Conv2D) │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2a_project_bn │ (None, None, │ 128 │ block2a_project_… │ │ (BatchNormalizatio… │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_conv │ (None, None, │ 36,864 │ block2a_project_… │ │ (Conv2D) │ None, 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_bn │ (None, None, │ 512 │ block2b_expand_c… │ │ (BatchNormalizatio… │ None, 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_expand_act… │ (None, None, │ 0 │ block2b_expand_b… │ │ (Activation) │ None, 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_project_co… │ (None, None, │ 4,096 │ block2b_expand_a… │ │ (Conv2D) │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_project_bn │ (None, None, │ 128 │ block2b_project_… │ │ (BatchNormalizatio… │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_drop │ (None, None, │ 0 │ block2b_project_… │ │ (Dropout) │ None, 32) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block2b_add (Add) │ (None, None, │ 0 │ block2b_drop[0][… │ │ │ None, 32) │ │ block2a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_conv │ (None, None, │ 36,864 │ block2b_add[0][0] │ │ (Conv2D) │ None, 128) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_bn │ (None, None, │ 512 │ block3a_expand_c… │ │ (BatchNormalizatio… │ None, 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_expand_act… │ (None, None, │ 0 │ block3a_expand_b… │ │ (Activation) │ None, 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_project_co… │ (None, None, │ 6,144 │ block3a_expand_a… │ │ (Conv2D) │ None, 48) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3a_project_bn │ (None, None, │ 192 │ block3a_project_… │ │ (BatchNormalizatio… │ None, 48) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_conv │ (None, None, │ 82,944 │ block3a_project_… │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_bn │ (None, None, │ 768 │ block3b_expand_c… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_expand_act… │ (None, None, │ 0 │ block3b_expand_b… │ │ (Activation) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_project_co… │ (None, None, │ 9,216 │ block3b_expand_a… │ │ (Conv2D) │ None, 48) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_project_bn │ (None, None, │ 192 │ block3b_project_… │ │ (BatchNormalizatio… │ None, 48) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_drop │ (None, None, │ 0 │ block3b_project_… │ │ (Dropout) │ None, 48) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block3b_add (Add) │ (None, None, │ 0 │ block3b_drop[0][… │ │ │ None, 48) │ │ block3a_project_… │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_conv │ (None, None, │ 9,216 │ block3b_add[0][0] │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_bn │ (None, None, │ 768 │ block4a_expand_c… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_expand_act… │ (None, None, │ 0 │ block4a_expand_b… │ │ (Activation) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_dwconv2 │ (None, None, │ 1,728 │ block4a_expand_a… │ │ (DepthwiseConv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_bn │ (None, None, │ 768 │ block4a_dwconv2[… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_activation │ (None, None, │ 0 │ block4a_bn[0][0] │ │ (Activation) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_squeeze │ (None, 192) │ 0 │ block4a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_reshape │ (None, 1, 1, 192) │ 0 │ block4a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_reduce │ (None, 1, 1, 12) │ 2,316 │ block4a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_expand │ (None, 1, 1, 192) │ 2,496 │ block4a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_se_excite │ (None, None, │ 0 │ block4a_activati… │ │ (Multiply) │ None, 192) │ │ block4a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block4a_project_co… │ (None, None, │ 18,432 │ block4a_se_excit… │ │ (Conv2D) │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4a_project_bn │ (None, None, │ 384 │ block4a_project_… │ │ (BatchNormalizatio… │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_conv │ (None, None, │ 36,864 │ block4a_project_… │ │ (Conv2D) │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_bn │ (None, None, │ 1,536 │ block4b_expand_c… │ │ (BatchNormalizatio… │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_expand_act… │ (None, None, │ 0 │ block4b_expand_b… │ │ (Activation) │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_dwconv2 │ (None, None, │ 3,456 │ block4b_expand_a… │ │ (DepthwiseConv2D) │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_bn │ (None, None, │ 1,536 │ block4b_dwconv2[… │ │ (BatchNormalizatio… │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_activation │ (None, None, │ 0 │ block4b_bn[0][0] │ │ (Activation) │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_squeeze │ (None, 384) │ 0 │ block4b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_reshape │ (None, 1, 1, 384) │ 0 │ block4b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_reduce │ (None, 1, 1, 24) │ 9,240 │ block4b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_expand │ (None, 1, 1, 384) │ 9,600 │ block4b_se_reduc… │ │ (Conv2D) │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_se_excite │ (None, None, │ 0 │ block4b_activati… │ │ (Multiply) │ None, 384) │ │ block4b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_project_co… │ (None, None, │ 36,864 │ block4b_se_excit… │ │ (Conv2D) │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_project_bn │ (None, None, │ 384 │ block4b_project_… │ │ (BatchNormalizatio… │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_drop │ (None, None, │ 0 │ block4b_project_… │ │ (Dropout) │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4b_add (Add) │ (None, None, │ 0 │ block4b_drop[0][… │ │ │ None, 96) │ │ block4a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_conv │ (None, None, │ 36,864 │ block4b_add[0][0] │ │ (Conv2D) │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_bn │ (None, None, │ 1,536 │ block4c_expand_c… │ │ (BatchNormalizatio… │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_expand_act… │ (None, None, │ 0 │ block4c_expand_b… │ │ (Activation) │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_dwconv2 │ (None, None, │ 3,456 │ block4c_expand_a… │ │ (DepthwiseConv2D) │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_bn │ (None, None, │ 1,536 │ block4c_dwconv2[… │ │ (BatchNormalizatio… │ None, 384) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_activation │ (None, None, │ 0 │ block4c_bn[0][0] │ │ (Activation) │ None, 384) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_squeeze │ (None, 384) │ 0 │ block4c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_reshape │ (None, 1, 1, 384) │ 0 │ block4c_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_reduce │ (None, 1, 1, 24) │ 9,240 │ block4c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_expand │ (None, 1, 1, 384) │ 9,600 │ block4c_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_se_excite │ (None, None, │ 0 │ block4c_activati… │ │ (Multiply) │ None, 384) │ │ block4c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_project_co… │ (None, None, │ 36,864 │ block4c_se_excit… │ │ (Conv2D) │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_project_bn │ (None, None, │ 384 │ block4c_project_… │ │ (BatchNormalizatio… │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_drop │ (None, None, │ 0 │ block4c_project_… │ │ (Dropout) │ None, 96) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block4c_add (Add) │ (None, None, │ 0 │ block4c_drop[0][… │ │ │ None, 96) │ │ block4b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_conv │ (None, None, │ 55,296 │ block4c_add[0][0] │ │ (Conv2D) │ None, 576) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_expand_bn │ (None, None, │ 2,304 │ block5a_expand_c… │ │ (BatchNormalizatio… │ None, 576) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block5a_expand_act… │ (None, None, │ 0 │ block5a_expand_b… │ │ (Activation) │ None, 576) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_dwconv2 │ (None, None, │ 5,184 │ block5a_expand_a… │ │ (DepthwiseConv2D) │ None, 576) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_bn │ (None, None, │ 2,304 │ block5a_dwconv2[… │ │ (BatchNormalizatio… │ None, 576) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_activation │ (None, None, │ 0 │ block5a_bn[0][0] │ │ (Activation) │ None, 576) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_squeeze │ (None, 576) │ 0 │ block5a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_reshape │ (None, 1, 1, 576) │ 0 │ block5a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_reduce │ (None, 1, 1, 24) │ 13,848 │ block5a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_expand │ (None, 1, 1, 576) │ 14,400 │ block5a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_se_excite │ (None, None, │ 0 │ block5a_activati… │ │ (Multiply) │ None, 576) │ │ block5a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_project_co… │ (None, None, │ 64,512 │ block5a_se_excit… │ │ (Conv2D) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5a_project_bn │ (None, None, │ 448 │ block5a_project_… │ │ (BatchNormalizatio… │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_conv │ (None, None, │ 75,264 │ block5a_project_… │ │ (Conv2D) │ None, 
672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_bn │ (None, None, │ 2,688 │ block5b_expand_c… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_expand_act… │ (None, None, │ 0 │ block5b_expand_b… │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_dwconv2 │ (None, None, │ 6,048 │ block5b_expand_a… │ │ (DepthwiseConv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_bn │ (None, None, │ 2,688 │ block5b_dwconv2[… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_activation │ (None, None, │ 0 │ block5b_bn[0][0] │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_squeeze │ (None, 672) │ 0 │ block5b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_reshape │ (None, 1, 1, 672) │ 0 │ block5b_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5b_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_se_excite │ (None, None, │ 0 │ block5b_activati… │ │ (Multiply) │ None, 672) │ │ block5b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_project_co… │ (None, None, │ 75,264 │ block5b_se_excit… │ │ (Conv2D) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ 
│ block5b_project_bn │ (None, None, │ 448 │ block5b_project_… │ │ (BatchNormalizatio… │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_drop │ (None, None, │ 0 │ block5b_project_… │ │ (Dropout) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5b_add (Add) │ (None, None, │ 0 │ block5b_drop[0][… │ │ │ None, 112) │ │ block5a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_conv │ (None, None, │ 75,264 │ block5b_add[0][0] │ │ (Conv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_bn │ (None, None, │ 2,688 │ block5c_expand_c… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_expand_act… │ (None, None, │ 0 │ block5c_expand_b… │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_dwconv2 │ (None, None, │ 6,048 │ block5c_expand_a… │ │ (DepthwiseConv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_bn │ (None, None, │ 2,688 │ block5c_dwconv2[… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_activation │ (None, None, │ 0 │ block5c_bn[0][0] │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_squeeze │ (None, 672) │ 0 │ block5c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_reshape │ (None, 1, 1, 672) │ 0 │ block5c_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5c_se_resha… │ │ (Conv2D) │ │ 
│ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5c_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_se_excite │ (None, None, │ 0 │ block5c_activati… │ │ (Multiply) │ None, 672) │ │ block5c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_project_co… │ (None, None, │ 75,264 │ block5c_se_excit… │ │ (Conv2D) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_project_bn │ (None, None, │ 448 │ block5c_project_… │ │ (BatchNormalizatio… │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_drop │ (None, None, │ 0 │ block5c_project_… │ │ (Dropout) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5c_add (Add) │ (None, None, │ 0 │ block5c_drop[0][… │ │ │ None, 112) │ │ block5b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_expand_conv │ (None, None, │ 75,264 │ block5c_add[0][0] │ │ (Conv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_expand_bn │ (None, None, │ 2,688 │ block5d_expand_c… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_expand_act… │ (None, None, │ 0 │ block5d_expand_b… │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_dwconv2 │ (None, None, │ 6,048 │ block5d_expand_a… │ │ (DepthwiseConv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_bn │ (None, None, │ 2,688 │ block5d_dwconv2[… │ │ (BatchNormalizatio… │ None, 672) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_activation │ (None, None, │ 0 │ block5d_bn[0][0] │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_se_squeeze │ (None, 672) │ 0 │ block5d_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_se_reshape │ (None, 1, 1, 672) │ 0 │ block5d_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5d_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5d_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_se_excite │ (None, None, │ 0 │ block5d_activati… │ │ (Multiply) │ None, 672) │ │ block5d_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_project_co… │ (None, None, │ 75,264 │ block5d_se_excit… │ │ (Conv2D) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_project_bn │ (None, None, │ 448 │ block5d_project_… │ │ (BatchNormalizatio… │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_drop │ (None, None, │ 0 │ block5d_project_… │ │ (Dropout) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5d_add (Add) │ (None, None, │ 0 │ block5d_drop[0][… │ │ │ None, 112) │ │ block5c_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_expand_conv │ (None, None, │ 75,264 │ block5d_add[0][0] │ │ (Conv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_expand_bn │ 
(None, None, │ 2,688 │ block5e_expand_c… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_expand_act… │ (None, None, │ 0 │ block5e_expand_b… │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_dwconv2 │ (None, None, │ 6,048 │ block5e_expand_a… │ │ (DepthwiseConv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_bn │ (None, None, │ 2,688 │ block5e_dwconv2[… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_activation │ (None, None, │ 0 │ block5e_bn[0][0] │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_se_squeeze │ (None, 672) │ 0 │ block5e_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_se_reshape │ (None, 1, 1, 672) │ 0 │ block5e_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block5e_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_se_expand │ (None, 1, 1, 672) │ 19,488 │ block5e_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_se_excite │ (None, None, │ 0 │ block5e_activati… │ │ (Multiply) │ None, 672) │ │ block5e_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_project_co… │ (None, None, │ 75,264 │ block5e_se_excit… │ │ (Conv2D) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_project_bn │ (None, None, │ 448 │ block5e_project_… │ │ (BatchNormalizatio… │ None, 112) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_drop │ (None, None, │ 0 │ block5e_project_… │ │ (Dropout) │ None, 112) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block5e_add (Add) │ (None, None, │ 0 │ block5e_drop[0][… │ │ │ None, 112) │ │ block5d_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_conv │ (None, None, │ 75,264 │ block5e_add[0][0] │ │ (Conv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_bn │ (None, None, │ 2,688 │ block6a_expand_c… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_expand_act… │ (None, None, │ 0 │ block6a_expand_b… │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_dwconv2 │ (None, None, │ 6,048 │ block6a_expand_a… │ │ (DepthwiseConv2D) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_bn │ (None, None, │ 2,688 │ block6a_dwconv2[… │ │ (BatchNormalizatio… │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_activation │ (None, None, │ 0 │ block6a_bn[0][0] │ │ (Activation) │ None, 672) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_squeeze │ (None, 672) │ 0 │ block6a_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_reshape │ (None, 1, 1, 672) │ 0 │ block6a_se_squee… │ │ (Reshape) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_reduce │ (None, 1, 1, 28) │ 18,844 │ block6a_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_expand │ 
(None, 1, 1, 672) │ 19,488 │ block6a_se_reduc… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_se_excite │ (None, None, │ 0 │ block6a_activati… │ │ (Multiply) │ None, 672) │ │ block6a_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_project_co… │ (None, None, │ 129,024 │ block6a_se_excit… │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6a_project_bn │ (None, None, │ 768 │ block6a_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_conv │ (None, None, │ 221,184 │ block6a_project_… │ │ (Conv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_bn │ (None, None, │ 4,608 │ block6b_expand_c… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_expand_act… │ (None, None, │ 0 │ block6b_expand_b… │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_dwconv2 │ (None, None, │ 10,368 │ block6b_expand_a… │ │ (DepthwiseConv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_bn │ (None, None, │ 4,608 │ block6b_dwconv2[… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_activation │ (None, None, │ 0 │ block6b_bn[0][0] │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_squeeze │ (None, 1152) │ 0 │ block6b_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_reshape │ (None, 1, 1, │ 0 │ block6b_se_squee… │ │ 
(Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6b_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_expand │ (None, 1, 1, │ 56,448 │ block6b_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_se_excite │ (None, None, │ 0 │ block6b_activati… │ │ (Multiply) │ None, 1152) │ │ block6b_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_project_co… │ (None, None, │ 221,184 │ block6b_se_excit… │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_project_bn │ (None, None, │ 768 │ block6b_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_drop │ (None, None, │ 0 │ block6b_project_… │ │ (Dropout) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6b_add (Add) │ (None, None, │ 0 │ block6b_drop[0][… │ │ │ None, 192) │ │ block6a_project_… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_conv │ (None, None, │ 221,184 │ block6b_add[0][0] │ │ (Conv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_bn │ (None, None, │ 4,608 │ block6c_expand_c… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_expand_act… │ (None, None, │ 0 │ block6c_expand_b… │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_dwconv2 │ (None, None, │ 10,368 │ block6c_expand_a… │ │ (DepthwiseConv2D) │ None, 1152) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_bn │ (None, None, │ 4,608 │ block6c_dwconv2[… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_activation │ (None, None, │ 0 │ block6c_bn[0][0] │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_squeeze │ (None, 1152) │ 0 │ block6c_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_reshape │ (None, 1, 1, │ 0 │ block6c_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6c_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_expand │ (None, 1, 1, │ 56,448 │ block6c_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_se_excite │ (None, None, │ 0 │ block6c_activati… │ │ (Multiply) │ None, 1152) │ │ block6c_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_project_co… │ (None, None, │ 221,184 │ block6c_se_excit… │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_project_bn │ (None, None, │ 768 │ block6c_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_drop │ (None, None, │ 0 │ block6c_project_… │ │ (Dropout) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6c_add (Add) │ (None, None, │ 0 │ block6c_drop[0][… │ │ │ None, 192) │ │ block6b_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block6d_expand_conv │ (None, None, │ 221,184 │ block6c_add[0][0] │ │ (Conv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_expand_bn │ (None, None, │ 4,608 │ block6d_expand_c… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_expand_act… │ (None, None, │ 0 │ block6d_expand_b… │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_dwconv2 │ (None, None, │ 10,368 │ block6d_expand_a… │ │ (DepthwiseConv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_bn │ (None, None, │ 4,608 │ block6d_dwconv2[… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_activation │ (None, None, │ 0 │ block6d_bn[0][0] │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_squeeze │ (None, 1152) │ 0 │ block6d_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_reshape │ (None, 1, 1, │ 0 │ block6d_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6d_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_expand │ (None, 1, 1, │ 56,448 │ block6d_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_se_excite │ (None, None, │ 0 │ block6d_activati… │ │ (Multiply) │ None, 1152) │ │ block6d_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_project_co… │ (None, None, │ 221,184 │ block6d_se_excit… │ │ 
(Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_project_bn │ (None, None, │ 768 │ block6d_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_drop │ (None, None, │ 0 │ block6d_project_… │ │ (Dropout) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6d_add (Add) │ (None, None, │ 0 │ block6d_drop[0][… │ │ │ None, 192) │ │ block6c_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_expand_conv │ (None, None, │ 221,184 │ block6d_add[0][0] │ │ (Conv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_expand_bn │ (None, None, │ 4,608 │ block6e_expand_c… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_expand_act… │ (None, None, │ 0 │ block6e_expand_b… │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_dwconv2 │ (None, None, │ 10,368 │ block6e_expand_a… │ │ (DepthwiseConv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_bn │ (None, None, │ 4,608 │ block6e_dwconv2[… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_activation │ (None, None, │ 0 │ block6e_bn[0][0] │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_se_squeeze │ (None, 1152) │ 0 │ block6e_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_se_reshape │ (None, 1, 1, │ 0 │ block6e_se_squee… │ │ (Reshape) │ 1152) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6e_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_se_expand │ (None, 1, 1, │ 56,448 │ block6e_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_se_excite │ (None, None, │ 0 │ block6e_activati… │ │ (Multiply) │ None, 1152) │ │ block6e_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_project_co… │ (None, None, │ 221,184 │ block6e_se_excit… │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_project_bn │ (None, None, │ 768 │ block6e_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_drop │ (None, None, │ 0 │ block6e_project_… │ │ (Dropout) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6e_add (Add) │ (None, None, │ 0 │ block6e_drop[0][… │ │ │ None, 192) │ │ block6d_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_expand_conv │ (None, None, │ 221,184 │ block6e_add[0][0] │ │ (Conv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_expand_bn │ (None, None, │ 4,608 │ block6f_expand_c… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_expand_act… │ (None, None, │ 0 │ block6f_expand_b… │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_dwconv2 │ (None, None, │ 10,368 │ block6f_expand_a… │ │ (DepthwiseConv2D) │ None, 1152) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_bn │ (None, None, │ 4,608 │ block6f_dwconv2[… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_activation │ (None, None, │ 0 │ block6f_bn[0][0] │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_se_squeeze │ (None, 1152) │ 0 │ block6f_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_se_reshape │ (None, 1, 1, │ 0 │ block6f_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6f_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_se_expand │ (None, 1, 1, │ 56,448 │ block6f_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_se_excite │ (None, None, │ 0 │ block6f_activati… │ │ (Multiply) │ None, 1152) │ │ block6f_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_project_co… │ (None, None, │ 221,184 │ block6f_se_excit… │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_project_bn │ (None, None, │ 768 │ block6f_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_drop │ (None, None, │ 0 │ block6f_project_… │ │ (Dropout) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6f_add (Add) │ (None, None, │ 0 │ block6f_drop[0][… │ │ │ None, 192) │ │ block6e_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
block6g_expand_conv │ (None, None, │ 221,184 │ block6f_add[0][0] │ │ (Conv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_expand_bn │ (None, None, │ 4,608 │ block6g_expand_c… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_expand_act… │ (None, None, │ 0 │ block6g_expand_b… │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_dwconv2 │ (None, None, │ 10,368 │ block6g_expand_a… │ │ (DepthwiseConv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_bn │ (None, None, │ 4,608 │ block6g_dwconv2[… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_activation │ (None, None, │ 0 │ block6g_bn[0][0] │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_se_squeeze │ (None, 1152) │ 0 │ block6g_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_se_reshape │ (None, 1, 1, │ 0 │ block6g_se_squee… │ │ (Reshape) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6g_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_se_expand │ (None, 1, 1, │ 56,448 │ block6g_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_se_excite │ (None, None, │ 0 │ block6g_activati… │ │ (Multiply) │ None, 1152) │ │ block6g_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_project_co… │ (None, None, │ 221,184 │ block6g_se_excit… │ │ 
(Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_project_bn │ (None, None, │ 768 │ block6g_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_drop │ (None, None, │ 0 │ block6g_project_… │ │ (Dropout) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6g_add (Add) │ (None, None, │ 0 │ block6g_drop[0][… │ │ │ None, 192) │ │ block6f_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_expand_conv │ (None, None, │ 221,184 │ block6g_add[0][0] │ │ (Conv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_expand_bn │ (None, None, │ 4,608 │ block6h_expand_c… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_expand_act… │ (None, None, │ 0 │ block6h_expand_b… │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_dwconv2 │ (None, None, │ 10,368 │ block6h_expand_a… │ │ (DepthwiseConv2D) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_bn │ (None, None, │ 4,608 │ block6h_dwconv2[… │ │ (BatchNormalizatio… │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_activation │ (None, None, │ 0 │ block6h_bn[0][0] │ │ (Activation) │ None, 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_se_squeeze │ (None, 1152) │ 0 │ block6h_activati… │ │ (GlobalAveragePool… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_se_reshape │ (None, 1, 1, │ 0 │ block6h_se_squee… │ │ (Reshape) │ 1152) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_se_reduce │ (None, 1, 1, 48) │ 55,344 │ block6h_se_resha… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_se_expand │ (None, 1, 1, │ 56,448 │ block6h_se_reduc… │ │ (Conv2D) │ 1152) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_se_excite │ (None, None, │ 0 │ block6h_activati… │ │ (Multiply) │ None, 1152) │ │ block6h_se_expan… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_project_co… │ (None, None, │ 221,184 │ block6h_se_excit… │ │ (Conv2D) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_project_bn │ (None, None, │ 768 │ block6h_project_… │ │ (BatchNormalizatio… │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_drop │ (None, None, │ 0 │ block6h_project_… │ │ (Dropout) │ None, 192) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ block6h_add (Add) │ (None, None, │ 0 │ block6h_drop[0][… │ │ │ None, 192) │ │ block6g_add[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_conv (Conv2D) │ (None, None, │ 245,760 │ block6h_add[0][0] │ │ │ None, 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_bn │ (None, None, │ 5,120 │ top_conv[0][0] │ │ (BatchNormalizatio… │ None, 1280) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ top_activation │ (None, None, │ 0 │ top_bn[0][0] │ │ (Activation) │ None, 1280) │ │ │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 5,919,312 (22.58 MB)
Trainable params: 0 (0.00 B)
Non-trainable params: 5,919,312 (22.58 MB)
We can see each of the different layers and their number of parameters. Since this model is pre-trained, you can think of all of these parameters as patterns the base model has learned on another dataset.
Due to setting base_model.trainable = False, these patterns remain as they are during training (they're frozen and don't get updated).
Alright, let's see the summary of the overall model.
# Check summary of model constructed with Functional API
model_0.summary()
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 224, 224, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ efficientnetv2-b0 (Functional) │ (None, 7, 7, 1280) │ 5,919,312 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling_layer │ (None, 1280) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ output_layer (Dense) │ (None, 10) │ 12,810 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 5,957,744 (22.73 MB)
Trainable params: 12,810 (50.04 KB)
Non-trainable params: 5,919,312 (22.58 MB)
Optimizer params: 25,622 (100.09 KB)
It looks like our model has only 4 layers, but in reality the efficientnetv2-b0 layer contains 269 layers of its own.
You can see the transformation of shape from the input of (None, 224, 224, 3) to the output of (None, 10), where None is a placeholder for the batch size.
The only trainable parameters in the model are those in the output layer.
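We can sanity-check that figure with plain arithmetic: a Dense layer has one weight per input-unit/output-unit pair, plus one bias per output unit. With a 1280-dimensional feature vector going into 10 output classes:

```python
# Parameters in a Dense layer = (inputs * units) + units (one bias per unit)
feature_vector_size = 1280  # size of EfficientNetV2B0's pooled feature output
num_classes = 10

dense_params = feature_vector_size * num_classes + num_classes
print(dense_params)  # 12810 -> matches "Trainable params" in the summary above
```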
So what do the model's training curves look like?
# check out model training curve history
plot_loss_curves(history_10_percent)
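plot_loss_curves() is a helper function from earlier in the series. If you don't have it handy, a minimal sketch (assuming a standard Keras History object with 'loss', 'val_loss', 'accuracy' and 'val_accuracy' keys) looks like:

```python
import matplotlib.pyplot as plt

def plot_loss_curves(history):
    """Plot separate loss and accuracy curves from a Keras History object."""
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    accuracy = history.history['accuracy']
    val_accuracy = history.history['val_accuracy']
    epochs = range(len(loss))

    # Loss curves
    plt.plot(epochs, loss, label='training_loss')
    plt.plot(epochs, val_loss, label='val_loss')
    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.legend()

    # Accuracy curves (on a separate figure)
    plt.figure()
    plt.plot(epochs, accuracy, label='training_accuracy')
    plt.plot(epochs, val_accuracy, label='val_accuracy')
    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.legend()
```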
Getting a feature vector from a trained model¶
What is a feature vector? > It's the output of the second-to-last layer of the model, i.e. the layer before the output layer. It's a compact 1D representation of the input data (e.g. shape (None, 2048)), before it gets classified into the classes of the output layer.
Question: What does the tf.keras.layers.GlobalAveragePooling2D() layer do?
It transforms a 4D tensor into a 2D tensor by averaging the values across the inner axes.
The previous sentence is a bit of a mouthful, so let's see this in action.
# Define input tensor shape (same number of dimensions as the output of efficientnetv2-b0)
input_shape = (1,4,4,3)
# Create a random tensor
tf.random.set_seed(35)
input_tensor = tf.random.normal(input_shape)
print(f"Random input tensor:\n {input_tensor}\n")
# Pass the random tensor through a global average pooling 2D layer
global_average_pooled_tensor = tf.keras.layers.GlobalAveragePooling2D()(input_tensor)
print(f'2D global average pooled random tensor:\n {global_average_pooled_tensor}\n')
# Check the shapes of the different tensors
print(f'Shape of input tensor: {input_tensor.shape}')
print(f'Shape of 2D global averaged pooled input tensor: {global_average_pooled_tensor.shape}')
Random input tensor: [[[[ 2.2217767 -0.30146012 1.0177914 ] [-1.2843658 -0.16923259 -0.7082916 ] [ 0.5628517 -0.9604295 -1.1108632 ] [-0.7519864 0.8463829 0.745947 ]] [[-0.24670422 -0.9917532 0.5135337 ] [-1.7907373 -0.28143376 -0.01009051] [-0.67895067 -0.23497029 -1.0445431 ] [-0.04763518 1.594667 -0.40636218]] [[-1.9090234 -0.87867856 -0.71947384] [ 0.04788789 0.2768906 1.1405766 ] [ 0.44342762 -0.67578 0.6897585 ] [ 1.1216972 -1.0744088 0.182657 ]] [[-0.6914823 -0.7301613 0.49244884] [-0.39586642 0.89530694 -1.5820867 ] [ 0.62116927 0.13989049 1.1740805 ] [-0.61668336 -0.7115485 -0.98318386]]]] 2D global average pooled random tensor: [[-0.21216401 -0.20354491 -0.03800634]] Shape of input tensor: (1, 4, 4, 3) Shape of 2D global averaged pooled input tensor: (1, 3)
The average pooling collapsed our tensor down to just 3 values, which might look odd at first. What's happening is that each innermost [] represents a pixel, and the 3 values inside it represent the RGB colour channels.
So R is the 1st column, G is the 2nd column, and B is the 3rd column, and the pooled result is one average per channel.
As we can see, the tf.keras.layers.GlobalAveragePooling2D() layer condensed the input tensor from shape (1, 4, 4, 3) to (1, 3). It did so by averaging input_tensor across the middle (4, 4) dimensions.
This can be replicated using the tf.reduce_mean() operation and specifying the appropriate axes.
# This is the same as GlobalAveragePooling2D()
tf.reduce_mean(input_tensor, axis=[1, 2]) # average across the middle (4, 4) axes; 1 and 2 are their indices in the tensor shape
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-0.21216401, -0.20354491, -0.03800634]], dtype=float32)>
Doing this not only makes the output of the base model compatible with the input shape requirement of our output layer (tf.keras.layers.Dense()), it also condenses the information found by the base model into a lower-dimensional feature vector.
In other words, average pooling cuts out all the extra work in between and delivers a 1-dimensional feature vector, which the output layer can use directly for classification.
Note: One of the reasons 'feature extraction transfer learning' is named the way it is, is that the pretrained model outputs a feature vector (the extracted features), which can then be used to make predictions on your own problem.
Practice: Do the same as above but with
tf.keras.layers.GlobalMaxPooling2D()
global_max_pooling_tensor = tf.keras.layers.GlobalMaxPooling2D()(input_tensor)
print(f'2D global max pooled random tensor:\n {global_max_pooling_tensor}')
2D global max pooled random tensor: [[2.2217767 1.594667 1.1740805]]
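Just as tf.reduce_mean() replicated average pooling, the max-pooled result can be replicated with a reduction op, this time tf.reduce_max() across the spatial axes:

```python
import tensorflow as tf

# Recreate the same random tensor as above
tf.random.set_seed(35)
input_tensor = tf.random.normal((1, 4, 4, 3))

# Global max pooling == taking the max across the spatial (height, width) axes
max_pooled = tf.reduce_max(input_tensor, axis=[1, 2])
print(max_pooled)  # should match the GlobalMaxPooling2D() output above (same seed/TF version)
```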
Running a series of transfer learning experiments¶
We've seen incredible results with only 10% of the training data. But what about 1%?
We'll answer that question by running the following experiments:
Model 1: Use feature extraction transfer learning on 1% of the training data with data augmentation.
Model 2: Use feature extraction transfer learning on 10% of the training data with data augmentation and save the results to a checkpoint.
Model 3: Fine-tune the Model 2 checkpoint on 10% of the training data with data augmentation.
Model 4: Fine-tune the Model 2 checkpoint on 100% of the training data with data augmentation.
All these experiments will use different versions of the training data, but they'll all be evaluated on the same test data, so we can fairly compare and contrast them.
All experiments are to be done with EfficientNetV2B0 model.
We'll make sure to use create_tensorboard_callback() to log and track our experiments.
We'll construct eachmodel using the Keras Functional API. Instead of implementing data augmentation on ImageDataGenerator class, we'll build it into model using tf.keras.layers module.
Let's begin by downloading data for experiment 1, using feature extraction transfer learning on 1% of the training data with data augmentation.
# create training and test dirs
train_dir_1_percent = '10_food_classes_1_percent/train/'
test_dir = '10_food_classes_10_percent/test/' # reuse the 10% test folder, since both datasets share the same test files
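If the dataset folders aren't on disk yet, a small helper can fetch and extract them. A sketch (the URL below is an assumption about where the course archives live; substitute the location of your own copy):

```python
import os
import urllib.request
import zipfile

def download_and_unzip(url, zip_name, extract_to="."):
    """Download a zip archive (skipped if already present) and extract it."""
    if not os.path.exists(zip_name):
        urllib.request.urlretrieve(url, zip_name)
    with zipfile.ZipFile(zip_name) as zf:
        zf.extractall(extract_to)

# Assumed dataset location -- replace with your own if this path doesn't work
# download_and_unzip(
#     "https://storage.googleapis.com/ztm_tf_course/food_vision/10_food_classes_1_percent.zip",
#     "10_food_classes_1_percent.zip")
```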
Let's look at the number of images we're working with
# walk through 1 percent data directory and list number of files
walk_through_dir('10_food_classes_1_percent'), walk_through_dir('10_food_classes_10_percent/test')
There are 1 directories and 0 images in '10_food_classes_1_percent'. There are 10 directories and 0 images in '10_food_classes_1_percent\train'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\ice_cream'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\chicken_curry'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\steak'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\sushi'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\chicken_wings'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\grilled_salmon'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\hamburger'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\pizza'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\ramen'. There are 0 directories and 7 images in '10_food_classes_1_percent\train\fried_rice'. There are 10 directories and 0 images in '10_food_classes_10_percent/test'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\ice_cream'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\chicken_curry'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\steak'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\sushi'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\chicken_wings'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\grilled_salmon'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\hamburger'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\pizza'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\ramen'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\fried_rice'.
(None, None)
We have only 7 images per training class, which will make things challenging for our models.
Time to load our images in as tf.data.Dataset objects. To do so, we'll use the image_dataset_from_directory() function.
import tensorflow as tf
IMG_SIZE = (224,224)
train_data_1_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir_1_percent,
label_mode='categorical',
batch_size=32, # default
image_size=IMG_SIZE)
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
label_mode='categorical',
image_size=IMG_SIZE)
Found 70 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
Adding data augmentation right into the model¶
Previously we've used different parameters of the ImageDataGenerator class to augment our training images. This time we'll build data augmentation right into the model.
How?
Using tf.keras.layers module and creating a dedicated data augmentation layer.
Augmentation has the following benefits:
- Preprocessing of images (augmenting them) happens on the GPU rather than the CPU. Text and structured data are more suited to being preprocessed on the CPU.
- Image augmentation only happens during training, so we can still export the whole model and use it elsewhere. If someone else wants to train the same model with the same augmentation, they can.
from IPython.display import Image, display
display(Image(filename='68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f6d7264626f75726b652f74656e736f72666c6f772d646565702d6c6561726e696e672f6d61696e2f696d616765732f30352d646174612d6175676d656e746174696f6e2d696e736964652d612d6d6f64656c2e706e67.png'))
Example use of data augmentation as first layer within model (EfficientNetB0).
Resource: For more info on different methods of data augmentation, check out the TensorFlow data augmentation guide.
To use data augmentation right within our model we'll create a Keras Sequential model consisting of only data preprocessing layers, then this can be used within another model.
Data augmentation transformations we're gonna use:
- tf.keras.layers.RandomFlip - flips an image on the horizontal or vertical axis.
- tf.keras.layers.RandomRotation - randomly rotates an image by a specified amount.
- tf.keras.layers.RandomZoom - randomly zooms into an image by a specified amount.
- tf.keras.layers.RandomHeight - randomly shifts image height by a specified amount.
- tf.keras.layers.RandomWidth - randomly shifts image width by a specified amount.
- tf.keras.layers.Rescaling - normalizes image pixel values to be between 0 and 1 (required for some models, as not all of them rescale inputs automatically).
There are more options, but this will be what we'll focus on for now.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# from tensorflow.keras.layers.experimental import preprocessing
# Newer tensorflow versions (2.10+) can use the tensorflow.keras.layers API directly for data augmentation
data_augmentation = keras.Sequential([
layers.RandomFlip('horizontal'),
layers.RandomRotation(0.2),
layers.RandomZoom(0.2),
layers.RandomHeight(0.2),
layers.RandomWidth(0.2),
# preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNetV2B0
], name='data_augmentation')
Our data augmentation Sequential model has been set up! We'll be able to slot this 'model' in as a layer in our transfer learning model.
Before that's done, let's test it out by passing random images through it.
# view a random image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import random
target_class = random.choice(train_data_1_percent.class_names) # chooses random class
target_dir = '10_food_classes_1_percent/train/' + target_class # create target directory
random_image = random.choice(os.listdir(target_dir)) # choose a random image from target directory
random_image_path = target_dir + '/' + random_image # create the chosen random image path
img = mpimg.imread(random_image_path) # read in the chosen target image
plt.imshow(img) # plot the target image
plt.title(f'Original random image from class: {target_class}')
plt.axis(False); # turn off the axes
# augment the image
augmented_img = data_augmentation(tf.expand_dims(img, axis=0)) # data augmentation model requires shape (None, height, width, 3)
plt.figure()
plt.imshow(tf.squeeze(augmented_img)/255.) # requires normalization after augmentation
plt.title(f'Augmented random image from class: {target_class}')
plt.axis(False);
This gives you an idea of how the image is augmented when run a few times. The same will happen to each training image passed through the model during training.
This gives the model more variety and reflects real-life photo-taking. Not all pictures will be perfect, so this helps the model identify more variations of the same class.
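To demystify what a layer like RandomFlip does under the hood, here's a minimal NumPy sketch (the function name and p parameter are illustrative, not Keras API):

```python
import numpy as np

def random_flip_horizontal(image, p=0.5, rng=None):
    """Flip an (H, W, C) image left-to-right with probability p --
    a plain-NumPy sketch of tf.keras.layers.RandomFlip('horizontal')."""
    rng = rng or np.random.default_rng()
    if rng.random() < p:
        return image[:, ::-1, :]  # reverse the width axis
    return image

img = np.arange(12).reshape(2, 2, 3)
always_flipped = random_flip_horizontal(img, p=1.0)  # p=1.0 forces the flip
```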
Let's build a model with the Functional API. We'll run through similar steps as before, with one difference: we'll add the data augmentation model as a layer immediately after the input layer.
Model 1: Feature extraction transfer learning on 1% of the data with data augmentation¶
# setup input shape and base model, freezing the base model layers
input_shape = (224,224,3)
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False, weights='imagenet', input_shape=input_shape)
base_model.trainable = False
# create input layer
inputs = layers.Input(shape=input_shape, name='input_layer')
# add in data augmentation sequential model as a layer
x = data_augmentation(inputs)
# give base_model inputs (after augmentation) and don't train it
x = base_model(x, training=False)
# pool output features of base model
x = layers.GlobalAveragePooling2D(name='global_average_pooling_layer')(x)
# put a dense layer on as the output
outputs = layers.Dense(10, activation='softmax', name='output_layer')(x)
# make a model with inputs and outputs
model_1 = keras.Model(inputs, outputs)
# compile the model
model_1.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
# fit the model
history_1_percent = model_1.fit(train_data_1_percent,
epochs=5,
steps_per_epoch=len(train_data_1_percent),
validation_data=test_data,
validation_steps=int(0.25*len(test_data)), # validate for less steps
# track model training logs
callbacks=[create_tensorboard_callback('transfer_learning', '1_percent_data_aug')])
Saving TensorBoard log files to: transfer_learning/1_percent_data_aug/20251006-204956 Epoch 1/5 3/3 ━━━━━━━━━━━━━━━━━━━━ 20s 5s/step - accuracy: 0.1571 - loss: 2.3574 - val_accuracy: 0.1595 - val_loss: 2.2266 Epoch 2/5 3/3 ━━━━━━━━━━━━━━━━━━━━ 7s 3s/step - accuracy: 0.2429 - loss: 2.1532 - val_accuracy: 0.2467 - val_loss: 2.1068 Epoch 3/5 3/3 ━━━━━━━━━━━━━━━━━━━━ 7s 3s/step - accuracy: 0.2857 - loss: 2.0147 - val_accuracy: 0.3405 - val_loss: 1.9988 Epoch 4/5 3/3 ━━━━━━━━━━━━━━━━━━━━ 7s 3s/step - accuracy: 0.4571 - loss: 1.8651 - val_accuracy: 0.4112 - val_loss: 1.9143 Epoch 5/5 3/3 ━━━━━━━━━━━━━━━━━━━━ 7s 3s/step - accuracy: 0.5571 - loss: 1.7401 - val_accuracy: 0.4655 - val_loss: 1.8379
Surprisingly, with just 7 images per class, the model reaches around 56% training accuracy and 47% validation accuracy by the fifth epoch.
Let's check out the model's summary. We should see the data augmentation layer just after the input layer.
# check out model summary
model_1.summary()
Model: "functional_3"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 224, 224, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ data_augmentation (Sequential) │ (None, None, None, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ efficientnetv2-b0 (Functional) │ (None, None, None, │ 5,919,312 │ │ │ 1280) │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling_layer │ (None, 1280) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ output_layer (Dense) │ (None, 10) │ 12,810 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 5,957,744 (22.73 MB)
Trainable params: 12,810 (50.04 KB)
Non-trainable params: 5,919,312 (22.58 MB)
Optimizer params: 25,622 (100.09 KB)
Yup, the augmentation layer is there. If we need to save and reload the model, the augmentation layer will come with it.
An important thing to remember is that data augmentation only runs during training. So when evaluating our model (i.e. predicting image classes on test data), data augmentation is turned off.
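The mechanism behind this is the training flag Keras passes through layers at call time. Conceptually (plain Python, illustrative only):

```python
def augmentation_layer(batch, training=False):
    """Conceptual sketch of a Keras preprocessing layer: it only transforms
    inputs when training=True; at inference it passes data through untouched."""
    if not training:
        return batch
    return [example[::-1] for example in batch]  # stand-in "augmentation"

batch = [[1, 2, 3], [4, 5, 6]]
print(augmentation_layer(batch, training=False))  # unchanged: [[1, 2, 3], [4, 5, 6]]
print(augmentation_layer(batch, training=True))   # transformed: [[3, 2, 1], [6, 5, 4]]
```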
To see that in action, let's evaluate our model on the test data.
# evaluate on the test data
results_1_percent_data_aug = model_1.evaluate(test_data)
results_1_percent_data_aug
79/79 ━━━━━━━━━━━━━━━━━━━━ 24s 305ms/step - accuracy: 0.4748 - loss: 1.8197
[1.8197177648544312, 0.4747999906539917]
Results may be slightly better or worse than the training log's, because during training we only validated on 25% of the test data via validation_steps=int(0.25 * len(test_data)). This speeds up each epoch while still giving a good idea of the model's performance.
Let's check out the model's loss curves.
# how does the model do with a data augmentation layer and 1% of the data?
plot_loss_curves(history_1_percent)
Based on the loss curves, the model looks like it would continue to improve if trained for more epochs. However, we've got more experiments to try out and compare first.
Model 2: Feature extraction transfer learning with 10% of data and data augmentation¶
Since we've tried augmentation with 1% of the data, let's try it with 10%.
Question: How do we know what experiments to run?
You won't always know in advance. Machine learning is driven by experimentation, trial and error, and curiosity: follow your ideas through and try them out. The worst that can happen is that a method doesn't work; at best, you gain valuable knowledge.
To run numerous experiments practically, reduce the time between them, especially training time. Train for a small number of epochs on smaller versions of the same dataset. What matters is how each experiment compares with the others.
Once you find a promising model, you can scale it up to the original dataset.
That's what we're doing here: increasing from 1% of the training data to 10% and comparing performance.
train_dir_10_percent = '10_food_classes_10_percent/train/'
test_dir = '10_food_classes_10_percent/test/'
# setup data inputs
import tensorflow as tf
IMG_SIZE = (224,224)
train_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir_10_percent,
label_mode='categorical',
image_size=IMG_SIZE)
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
label_mode='categorical',
image_size=IMG_SIZE)
Found 750 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
Let's build a model with data augmentation built into it. You could copy and paste the augmentation Sequential model from before, but it's good practice to recreate it by hand.
# create a functional model with data augmentation
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
# newer TF versions (2.10+) can use the tensorflow.keras.layers API directly for data augmentation
data_augmentation = keras.Sequential([
layers.RandomFlip('horizontal'),
layers.RandomRotation(0.2),
layers.RandomZoom(0.2),
layers.RandomHeight(0.2),
layers.RandomWidth(0.2),
# preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNet
], name='data_augmentation')
## OLD
# # Build data augmentation layer
# data_augmentation = Sequential([
# preprocessing.RandomFlip('horizontal'),
# preprocessing.RandomHeight(0.2),
# preprocessing.RandomWidth(0.2),
# preprocessing.RandomZoom(0.2),
# preprocessing.RandomRotation(0.2),
# # preprocessing.Rescaling(1./255) # keep for ResNet50V2, remove for EfficientNet
# ], name="data_augmentation")
# setup the input shape to our model
input_shape = (224,224,3)
# create a frozen base model
# base_model = tf.keras.applications.EfficientNetB0(include_top=False)
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
base_model.trainable=False
# create input and output layers
inputs = layers.Input(shape=input_shape, name='input_layer') # create input layer
x = data_augmentation(inputs) # augment our training images
x = base_model(x, training=False) # pass augmented images to base model but keep it in inference mode, so batchnorm layers don't get updated
x = layers.GlobalAveragePooling2D(name='global_average_pooling_layer')(x)
outputs = layers.Dense(10, activation='softmax', name='output_layer')(x)
model_2 = tf.keras.Model(inputs, outputs)
# compile
model_2.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), # use Adam optimizer with base learning rate
metrics=['accuracy'])
Since we'll want to perform multiple experiments with models like model_2, let's put the above into a function for reuse.
def create_base_model(input_shape: tuple[int,int,int] = (224,224,3),
output_shape: int=10,
learning_rate: float=0.001,
training: bool=False) -> tf.keras.Model:
'''
Create a model based on EfficientNetV2B0 with built-in data augmentation.
Parameters:
- input_shape (tuple): Expected shape of input images. Default is (224,224,3).
- output_shape (int): Number of classes for the output layer. Default is 10
- learning_rate (float): Learning rate for the Adam optimizer. Default is 0.001.
- training (bool): Whether the base model is trainable. Default is False
Returns:
- tf.keras.Model: The compiled model with specified input and output settings.
'''
# create base model
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(include_top=False)
base_model.trainable = training
# setup model inputs and outputs with data augmentation built-in
inputs = layers.Input(shape=input_shape, name='input_layer')
x = data_augmentation(inputs)
x = base_model(x, training=False) # pass augmented images to base model but keep it in inference mode
x = layers.GlobalAveragePooling2D(name='global_average_pooling_layer')(x)
outputs = layers.Dense(units=output_shape, activation='softmax', name='output_layer')(x)
model = tf.keras.Model(inputs, outputs)
# compile model
model.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
metrics=['accuracy'])
return model
# create an instance of model_2 with our new function
model_2 = create_base_model()
Creating a ModelCheckpoint callback¶
Before fitting the model to the training data, let's introduce the ModelCheckpoint callback.
The ModelCheckpoint callback saves your model (or just its weights) to a specified directory while it trains.
This is especially helpful if training takes a long time and you want backups along the way. It also means that if you think the model would benefit from a longer training period, you can reload it from a specific checkpoint and continue training from there.
For example, say you have a feature extraction model trained for 5 epochs. Based on the trajectory of its history, you think it would continue to improve for another 5 epochs. You can load the checkpoint for the model, optionally unfreeze some or all of its layers, and continue training.
And that's what we're gonna do!
Before that, let's create ModelCheckpoint callback, and specify a directory to save it to.
# setup checkpoint path
checkpoint_path = "ten_percent_model_checkpoints_weights/checkpoint.weights.h5" # note: remember saving directly to colab is temporary
# create a ModelCheckpoint callback that saves the model's weights only
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True, # set to False to save the entire model
save_best_only=True, # save only the best model weights instead of a model every epoch
save_freq='epoch', # save every epoch
verbose=1)
Question: What's the difference between saving the entire model and saving only its weights?
The SavedModel format saves the model's architecture, weights, and training config all in one, which makes it very easy to reload the model elsewhere. Saving only the weights is useful if you don't want to share the full details of the model, or if you want to save disk space and save faster (you'll need to rebuild the same architecture before loading the weights back in).
We'll also add in our checkpoint_callback in our list of callbacks.
# fit the model saving checkpoints every epoch
initial_epochs=5
history_10_percent_data_aug = model_2.fit(train_data_10_percent,
epochs=initial_epochs,
validation_data=test_data,
validation_steps=int(0.25 * len(test_data)), # do less steps per validation (quicker)
callbacks=[create_tensorboard_callback('transfer_learning', '10_percent_data_aug'),
checkpoint_callback])
Saving TensorBoard log files to: transfer_learning/10_percent_data_aug/20251006-205113 Epoch 1/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 377ms/step - accuracy: 0.1913 - loss: 2.2321 Epoch 1: val_loss improved from None to 1.55309, saving model to ten_percent_model_checkpoints_weights/checkpoint.weights.h5 24/24 ━━━━━━━━━━━━━━━━━━━━ 28s 781ms/step - accuracy: 0.3120 - loss: 2.0551 - val_accuracy: 0.6727 - val_loss: 1.5531 Epoch 2/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 405ms/step - accuracy: 0.5755 - loss: 1.5684 Epoch 2: val_loss improved from 1.55309 to 1.11007, saving model to ten_percent_model_checkpoints_weights/checkpoint.weights.h5 24/24 ━━━━━━━━━━━━━━━━━━━━ 16s 695ms/step - accuracy: 0.6080 - loss: 1.4725 - val_accuracy: 0.7730 - val_loss: 1.1101 Epoch 3/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 370ms/step - accuracy: 0.7100 - loss: 1.2084 Epoch 3: val_loss improved from 1.11007 to 0.89955, saving model to ten_percent_model_checkpoints_weights/checkpoint.weights.h5 24/24 ━━━━━━━━━━━━━━━━━━━━ 15s 650ms/step - accuracy: 0.7187 - loss: 1.1685 - val_accuracy: 0.8059 - val_loss: 0.8996 Epoch 4/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 377ms/step - accuracy: 0.7401 - loss: 1.0095 Epoch 4: val_loss improved from 0.89955 to 0.78334, saving model to ten_percent_model_checkpoints_weights/checkpoint.weights.h5 24/24 ━━━━━━━━━━━━━━━━━━━━ 16s 657ms/step - accuracy: 0.7307 - loss: 1.0163 - val_accuracy: 0.8043 - val_loss: 0.7833 Epoch 5/5 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 372ms/step - accuracy: 0.7532 - loss: 0.9251 Epoch 5: val_loss improved from 0.78334 to 0.70671, saving model to ten_percent_model_checkpoints_weights/checkpoint.weights.h5 24/24 ━━━━━━━━━━━━━━━━━━━━ 16s 650ms/step - accuracy: 0.7627 - loss: 0.9040 - val_accuracy: 0.8289 - val_loss: 0.7067
The ModelCheckpoint callback worked, saving at every epoch without costing too much time or resources!
Let's evaluate the model and check its loss curves.
# evaluate on the test data
results_10_percent_data_aug = model_2.evaluate(test_data)
results_10_percent_data_aug
79/79 ━━━━━━━━━━━━━━━━━━━━ 24s 307ms/step - accuracy: 0.8164 - loss: 0.7025
[0.7025159597396851, 0.8163999915122986]
# plot model loss curves
plot_loss_curves(history_10_percent_data_aug)
Let's compare the curves to model_0's, which was trained on 10% of the data but without data augmentation.
# check out model training curve history
plot_loss_curves(history_10_percent)
We can see that adding augmentation on the same amount of data didn't show as quick an improvement as training without it.
However, the non-augmented model's curves flattened out quickly, with training accuracy overtaking validation accuracy. With augmented data, the curves suggest there's still room to improve, with validation accuracy sitting comfortably above training accuracy.
So if we trained the augmented model for a few more epochs, it might well beat the non-augmented one.
Since we've checkpointed the model's weights, we can load them back in and verify they saved correctly by evaluating on the test data.
To load saved model weights, use the load_weights() method, passing it the path where the weights were saved.
# load in saved model weights and evaluate model
model_2.load_weights(checkpoint_path)
loaded_weights_model_results = model_2.evaluate(test_data)
79/79 ━━━━━━━━━━━━━━━━━━━━ 24s 302ms/step - accuracy: 0.8164 - loss: 0.7025
Now let's compare the results to the previously trained model. They should be very close, with any minor differences attributable to floating-point precision and rounding.
# if the results from our native model and loaded weights are the same, then output should be True
results_10_percent_data_aug == loaded_weights_model_results
True
results_10_percent_data_aug
[0.7025159597396851, 0.8163999915122986]
loaded_weights_model_results
[0.7025159597396851, 0.8163999915122986]
They match exactly here, but floating-point results aren't always identical, so it's also worth checking with np.isclose().
import numpy as np
# check to see if loaded model results are very close to native model results (should output True)
np.isclose(np.array(results_10_percent_data_aug), np.array(loaded_weights_model_results))
array([ True, True])
# check the difference between the two results (small values)
print(np.array(results_10_percent_data_aug) - np.array(loaded_weights_model_results))
[0. 0.]
Model 3: Fine-tuning an existing model on 10% of the data¶
display(Image(filename='68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f6d7264626f75726b652f74656e736f72666c6f772d646565702d6c6561726e696e672f6d61696e2f696d616765732f30352d66696e652d74756e696e672d616e2d656666696369656e746e65742d6d6f64656c2e706e67.png'))
High-level example of fine-tuning an EfficientNet model. The bottom layers (those closest to the input data) stay frozen, whereas the top layers (those closest to the output) are gradually unfrozen and updated.
So far our saved model has been trained using feature extraction transfer learning for 5 epochs on 10% of the training data and data augmentation.
This means all of the layers were frozen during training.
For the next experiment we'll use fine-tuning transfer learning: the same base model, except we'll unfreeze some of its layers (those closest to the top) and train for a few more epochs.
Note: Fine-tuning usually works best after training a feature extraction model for a few epochs and with large amounts of data. More info can be checked out here: Keras' guide on Transfer learning & fine-tuning.
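Mechanically, fine-tuning comes down to flipping trainable flags on the top layers and then re-compiling (re-compiling is required for the change to take effect). Here's the pattern sketched over stand-in objects rather than real Keras layers; the same loop works over a real base model's .layers list:

```python
from types import SimpleNamespace

def unfreeze_top_layers(layers, n):
    """Keep all but the last n layers frozen; make the last n trainable."""
    for layer in layers[:-n]:
        layer.trainable = False
    for layer in layers[-n:]:
        layer.trainable = True

# Demo on stand-in "layers" (SimpleNamespace objects, not real Keras layers)
demo_layers = [SimpleNamespace(trainable=False) for _ in range(5)]
unfreeze_top_layers(demo_layers, 2)
print([layer.trainable for layer in demo_layers])  # [False, False, False, True, True]
```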
Let's check out its layers.
# layers in loaded model
model_2.layers
[<InputLayer name=input_layer, built=True>, <Sequential name=data_augmentation, built=True>, <Functional name=efficientnetv2-b0, built=True>, <GlobalAveragePooling2D name=global_average_pooling_layer, built=True>, <Dense name=output_layer, built=True>]
How about we check their names, types and whether they're trainable?
for layer_number, layer in enumerate(model_2.layers):
    print(f'Layer number: {layer_number} | Layer name: {layer.name} | Layer type: {type(layer).__name__} | Trainable? {layer.trainable}')
Layer number: 0 | Layer name: <InputLayer name=input_layer, built=True> | Layer type: <InputLayer name=input_layer, built=True> | Trainable? True Layer number: 1 | Layer name: <Sequential name=data_augmentation, built=True> | Layer type: <Sequential name=data_augmentation, built=True> | Trainable? True Layer number: 2 | Layer name: <Functional name=efficientnetv2-b0, built=True> | Layer type: <Functional name=efficientnetv2-b0, built=True> | Trainable? False Layer number: 3 | Layer name: <GlobalAveragePooling2D name=global_average_pooling_layer, built=True> | Layer type: <GlobalAveragePooling2D name=global_average_pooling_layer, built=True> | Trainable? True Layer number: 4 | Layer name: <Dense name=output_layer, built=True> | Layer type: <Dense name=output_layer, built=True> | Trainable? True
We've got an input layer, a Sequential layer (the data augmentation model), a Functional layer (EfficientNetV2B0), a pooling layer and a Dense layer (the output layer).
How about a summary?
model_2.summary()
Model: "functional_6"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 224, 224, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ data_augmentation (Sequential) │ (None, None, None, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ efficientnetv2-b0 (Functional) │ (None, None, None, │ 5,919,312 │ │ │ 1280) │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling_layer │ (None, 1280) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ output_layer (Dense) │ (None, 10) │ 12,810 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 5,957,744 (22.73 MB)
Trainable params: 12,810 (50.04 KB)
Non-trainable params: 5,919,312 (22.58 MB)
Optimizer params: 25,622 (100.09 KB)
It looks like the layers in efficientnetv2-b0 are frozen. We can confirm this using the trainable_variables attribute.
Note: The layers of base_model (aka efficientnetv2-b0) within model_2 are accessible via model_2.layers[2].
# access the base_model layers of model_2
model_2_base_model = model_2.layers[2]
model_2_base_model.name
'efficientnetv2-b0'
# How many layers are trainable in our model_2_base_model?
print(len(model_2_base_model.trainable_variables)) # layer at index 2 is the EfficientNetV2B0 layer (the base model)
0
We can even check layer by layer to see if they're trainable.
To access the layers in model_2_base_model, we can use the layers attribute.
# check which layers are tunable (trainable)
for layer_number, layer in enumerate(model_2_base_model.layers):
    print(layer_number, layer.name, layer.trainable)
0 input_layer_9 False 1 rescaling_6 False 2 normalization_6 False 3 stem_conv False 4 stem_bn False 5 stem_activation False 6 block1a_project_conv False 7 block1a_project_bn False 8 block1a_project_activation False 9 block2a_expand_conv False 10 block2a_expand_bn False 11 block2a_expand_activation False 12 block2a_project_conv False 13 block2a_project_bn False 14 block2b_expand_conv False 15 block2b_expand_bn False 16 block2b_expand_activation False 17 block2b_project_conv False 18 block2b_project_bn False 19 block2b_drop False 20 block2b_add False 21 block3a_expand_conv False 22 block3a_expand_bn False 23 block3a_expand_activation False 24 block3a_project_conv False 25 block3a_project_bn False 26 block3b_expand_conv False 27 block3b_expand_bn False 28 block3b_expand_activation False 29 block3b_project_conv False 30 block3b_project_bn False 31 block3b_drop False 32 block3b_add False 33 block4a_expand_conv False 34 block4a_expand_bn False 35 block4a_expand_activation False 36 block4a_dwconv2 False 37 block4a_bn False 38 block4a_activation False 39 block4a_se_squeeze False 40 block4a_se_reshape False 41 block4a_se_reduce False 42 block4a_se_expand False 43 block4a_se_excite False 44 block4a_project_conv False 45 block4a_project_bn False 46 block4b_expand_conv False 47 block4b_expand_bn False 48 block4b_expand_activation False 49 block4b_dwconv2 False 50 block4b_bn False 51 block4b_activation False 52 block4b_se_squeeze False 53 block4b_se_reshape False 54 block4b_se_reduce False 55 block4b_se_expand False 56 block4b_se_excite False 57 block4b_project_conv False 58 block4b_project_bn False 59 block4b_drop False 60 block4b_add False 61 block4c_expand_conv False 62 block4c_expand_bn False 63 block4c_expand_activation False 64 block4c_dwconv2 False 65 block4c_bn False 66 block4c_activation False 67 block4c_se_squeeze False 68 block4c_se_reshape False 69 block4c_se_reduce False 70 block4c_se_expand False 71 block4c_se_excite False 72 block4c_project_conv False 73 
block4c_project_bn False 74 block4c_drop False 75 block4c_add False 76 block5a_expand_conv False 77 block5a_expand_bn False 78 block5a_expand_activation False 79 block5a_dwconv2 False 80 block5a_bn False 81 block5a_activation False 82 block5a_se_squeeze False 83 block5a_se_reshape False 84 block5a_se_reduce False 85 block5a_se_expand False 86 block5a_se_excite False 87 block5a_project_conv False 88 block5a_project_bn False 89 block5b_expand_conv False 90 block5b_expand_bn False 91 block5b_expand_activation False 92 block5b_dwconv2 False 93 block5b_bn False 94 block5b_activation False 95 block5b_se_squeeze False 96 block5b_se_reshape False 97 block5b_se_reduce False 98 block5b_se_expand False 99 block5b_se_excite False 100 block5b_project_conv False 101 block5b_project_bn False 102 block5b_drop False 103 block5b_add False 104 block5c_expand_conv False 105 block5c_expand_bn False 106 block5c_expand_activation False 107 block5c_dwconv2 False 108 block5c_bn False 109 block5c_activation False 110 block5c_se_squeeze False 111 block5c_se_reshape False 112 block5c_se_reduce False 113 block5c_se_expand False 114 block5c_se_excite False 115 block5c_project_conv False 116 block5c_project_bn False 117 block5c_drop False 118 block5c_add False 119 block5d_expand_conv False 120 block5d_expand_bn False 121 block5d_expand_activation False 122 block5d_dwconv2 False 123 block5d_bn False 124 block5d_activation False 125 block5d_se_squeeze False 126 block5d_se_reshape False 127 block5d_se_reduce False 128 block5d_se_expand False 129 block5d_se_excite False 130 block5d_project_conv False 131 block5d_project_bn False 132 block5d_drop False 133 block5d_add False 134 block5e_expand_conv False 135 block5e_expand_bn False 136 block5e_expand_activation False 137 block5e_dwconv2 False 138 block5e_bn False 139 block5e_activation False 140 block5e_se_squeeze False 141 block5e_se_reshape False 142 block5e_se_reduce False 143 block5e_se_expand False 144 block5e_se_excite False 145 
block5e_project_conv False 146 block5e_project_bn False 147 block5e_drop False 148 block5e_add False 149 block6a_expand_conv False 150 block6a_expand_bn False 151 block6a_expand_activation False 152 block6a_dwconv2 False 153 block6a_bn False 154 block6a_activation False 155 block6a_se_squeeze False 156 block6a_se_reshape False 157 block6a_se_reduce False 158 block6a_se_expand False 159 block6a_se_excite False 160 block6a_project_conv False 161 block6a_project_bn False 162 block6b_expand_conv False 163 block6b_expand_bn False 164 block6b_expand_activation False 165 block6b_dwconv2 False 166 block6b_bn False 167 block6b_activation False 168 block6b_se_squeeze False 169 block6b_se_reshape False 170 block6b_se_reduce False 171 block6b_se_expand False 172 block6b_se_excite False 173 block6b_project_conv False 174 block6b_project_bn False 175 block6b_drop False 176 block6b_add False 177 block6c_expand_conv False 178 block6c_expand_bn False 179 block6c_expand_activation False 180 block6c_dwconv2 False 181 block6c_bn False 182 block6c_activation False 183 block6c_se_squeeze False 184 block6c_se_reshape False 185 block6c_se_reduce False 186 block6c_se_expand False 187 block6c_se_excite False 188 block6c_project_conv False 189 block6c_project_bn False 190 block6c_drop False 191 block6c_add False 192 block6d_expand_conv False 193 block6d_expand_bn False 194 block6d_expand_activation False 195 block6d_dwconv2 False 196 block6d_bn False 197 block6d_activation False 198 block6d_se_squeeze False 199 block6d_se_reshape False 200 block6d_se_reduce False 201 block6d_se_expand False 202 block6d_se_excite False 203 block6d_project_conv False 204 block6d_project_bn False 205 block6d_drop False 206 block6d_add False 207 block6e_expand_conv False 208 block6e_expand_bn False 209 block6e_expand_activation False 210 block6e_dwconv2 False 211 block6e_bn False 212 block6e_activation False 213 block6e_se_squeeze False 214 block6e_se_reshape False 215 block6e_se_reduce False 216 
block6e_se_expand False 217 block6e_se_excite False 218 block6e_project_conv False 219 block6e_project_bn False 220 block6e_drop False 221 block6e_add False 222 block6f_expand_conv False 223 block6f_expand_bn False 224 block6f_expand_activation False 225 block6f_dwconv2 False 226 block6f_bn False 227 block6f_activation False 228 block6f_se_squeeze False 229 block6f_se_reshape False 230 block6f_se_reduce False 231 block6f_se_expand False 232 block6f_se_excite False 233 block6f_project_conv False 234 block6f_project_bn False 235 block6f_drop False 236 block6f_add False 237 block6g_expand_conv False 238 block6g_expand_bn False 239 block6g_expand_activation False 240 block6g_dwconv2 False 241 block6g_bn False 242 block6g_activation False 243 block6g_se_squeeze False 244 block6g_se_reshape False 245 block6g_se_reduce False 246 block6g_se_expand False 247 block6g_se_excite False 248 block6g_project_conv False 249 block6g_project_bn False 250 block6g_drop False 251 block6g_add False 252 block6h_expand_conv False 253 block6h_expand_bn False 254 block6h_expand_activation False 255 block6h_dwconv2 False 256 block6h_bn False 257 block6h_activation False 258 block6h_se_squeeze False 259 block6h_se_reshape False 260 block6h_se_reduce False 261 block6h_se_expand False 262 block6h_se_excite False 263 block6h_project_conv False 264 block6h_project_bn False 265 block6h_drop False 266 block6h_add False 267 top_conv False 268 top_bn False 269 top_activation False
Sweet, all are frozen.
Now, to fine-tune the model to our data, we'll unfreeze the top 10 layers of the base model and continue training for another 5 epochs. The remaining layers will stay frozen and won't be updated during training.
Question: How many layers should you unfreeze when training?
There's no set rule for how many layers you should unfreeze. You could unfreeze them one by one, or unfreeze the entire model at once. As a general rule of thumb, the less training data you have, the fewer layers you should unfreeze at a time, fine-tuning gradually.
Resources: The ULMFiT (Universal Language Model Fine-tuning for Text Classification) paper has a great series of experiments on fine-tuning models.
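The freeze/unfreeze pattern we're about to use is really just list slicing over model.layers. Here's a minimal framework-free sketch (MockLayer is a hypothetical stand-in for a Keras layer, not a real Keras class) showing how unfreezing everything and then re-freezing all but the last N layers works:

```python
from dataclasses import dataclass

@dataclass
class MockLayer:
    # Hypothetical stand-in for a Keras layer's name/trainable attributes
    name: str
    trainable: bool = False

def refreeze_all_but_last_n(layers, n):
    """Set trainable=True on the last n layers, False on all the rest."""
    for layer in layers:
        layer.trainable = True   # unfreeze everything first
    for layer in layers[:-n]:    # then re-freeze all but the last n
        layer.trainable = False
    return layers

layers = [MockLayer(f'layer_{i}') for i in range(270)]  # EfficientNet-sized list
refreeze_all_but_last_n(layers, n=10)
print(sum(layer.trainable for layer in layers))  # → 10
```

The same slicing logic applies unchanged to a real base_model.layers list.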
To begin fine-tuning, we'll unfreeze the entire model_2_base_model by setting its trainable attribute to True.
Then we'll refreeze every layer in model_2_base_model except the last 10 by looping through them and setting their trainable attribute to False.
Finally we'll recompile the whole model.
# Make all the layers in model_2_base_model trainable
model_2_base_model.trainable = True
# Freeze all layers except for the last 10
for layer in model_2_base_model.layers[:-10]:
    layer.trainable = False
# Recompile the whole model (always recompile after any adjustments to a model)
model_2.compile(loss='categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), # lr is 10x lower than before for fine-tuning
                metrics=['accuracy'])
Now let's check which layers of the pretrained model are trainable.
# Check which layers are tuneable (trainable)
for layer_number, layer in enumerate(model_2_base_model.layers):
    print(layer_number, layer.name, layer.trainable)
0 input_layer_9 False 1 rescaling_6 False 2 normalization_6 False 3 stem_conv False 4 stem_bn False 5 stem_activation False 6 block1a_project_conv False 7 block1a_project_bn False 8 block1a_project_activation False 9 block2a_expand_conv False 10 block2a_expand_bn False 11 block2a_expand_activation False 12 block2a_project_conv False 13 block2a_project_bn False 14 block2b_expand_conv False 15 block2b_expand_bn False 16 block2b_expand_activation False 17 block2b_project_conv False 18 block2b_project_bn False 19 block2b_drop False 20 block2b_add False 21 block3a_expand_conv False 22 block3a_expand_bn False 23 block3a_expand_activation False 24 block3a_project_conv False 25 block3a_project_bn False 26 block3b_expand_conv False 27 block3b_expand_bn False 28 block3b_expand_activation False 29 block3b_project_conv False 30 block3b_project_bn False 31 block3b_drop False 32 block3b_add False 33 block4a_expand_conv False 34 block4a_expand_bn False 35 block4a_expand_activation False 36 block4a_dwconv2 False 37 block4a_bn False 38 block4a_activation False 39 block4a_se_squeeze False 40 block4a_se_reshape False 41 block4a_se_reduce False 42 block4a_se_expand False 43 block4a_se_excite False 44 block4a_project_conv False 45 block4a_project_bn False 46 block4b_expand_conv False 47 block4b_expand_bn False 48 block4b_expand_activation False 49 block4b_dwconv2 False 50 block4b_bn False 51 block4b_activation False 52 block4b_se_squeeze False 53 block4b_se_reshape False 54 block4b_se_reduce False 55 block4b_se_expand False 56 block4b_se_excite False 57 block4b_project_conv False 58 block4b_project_bn False 59 block4b_drop False 60 block4b_add False 61 block4c_expand_conv False 62 block4c_expand_bn False 63 block4c_expand_activation False 64 block4c_dwconv2 False 65 block4c_bn False 66 block4c_activation False 67 block4c_se_squeeze False 68 block4c_se_reshape False 69 block4c_se_reduce False 70 block4c_se_expand False 71 block4c_se_excite False 72 block4c_project_conv False 73 
block4c_project_bn False 74 block4c_drop False 75 block4c_add False 76 block5a_expand_conv False 77 block5a_expand_bn False 78 block5a_expand_activation False 79 block5a_dwconv2 False 80 block5a_bn False 81 block5a_activation False 82 block5a_se_squeeze False 83 block5a_se_reshape False 84 block5a_se_reduce False 85 block5a_se_expand False 86 block5a_se_excite False 87 block5a_project_conv False 88 block5a_project_bn False 89 block5b_expand_conv False 90 block5b_expand_bn False 91 block5b_expand_activation False 92 block5b_dwconv2 False 93 block5b_bn False 94 block5b_activation False 95 block5b_se_squeeze False 96 block5b_se_reshape False 97 block5b_se_reduce False 98 block5b_se_expand False 99 block5b_se_excite False 100 block5b_project_conv False 101 block5b_project_bn False 102 block5b_drop False 103 block5b_add False 104 block5c_expand_conv False 105 block5c_expand_bn False 106 block5c_expand_activation False 107 block5c_dwconv2 False 108 block5c_bn False 109 block5c_activation False 110 block5c_se_squeeze False 111 block5c_se_reshape False 112 block5c_se_reduce False 113 block5c_se_expand False 114 block5c_se_excite False 115 block5c_project_conv False 116 block5c_project_bn False 117 block5c_drop False 118 block5c_add False 119 block5d_expand_conv False 120 block5d_expand_bn False 121 block5d_expand_activation False 122 block5d_dwconv2 False 123 block5d_bn False 124 block5d_activation False 125 block5d_se_squeeze False 126 block5d_se_reshape False 127 block5d_se_reduce False 128 block5d_se_expand False 129 block5d_se_excite False 130 block5d_project_conv False 131 block5d_project_bn False 132 block5d_drop False 133 block5d_add False 134 block5e_expand_conv False 135 block5e_expand_bn False 136 block5e_expand_activation False 137 block5e_dwconv2 False 138 block5e_bn False 139 block5e_activation False 140 block5e_se_squeeze False 141 block5e_se_reshape False 142 block5e_se_reduce False 143 block5e_se_expand False 144 block5e_se_excite False 145 
block5e_project_conv False 146 block5e_project_bn False 147 block5e_drop False 148 block5e_add False 149 block6a_expand_conv False 150 block6a_expand_bn False 151 block6a_expand_activation False 152 block6a_dwconv2 False 153 block6a_bn False 154 block6a_activation False 155 block6a_se_squeeze False 156 block6a_se_reshape False 157 block6a_se_reduce False 158 block6a_se_expand False 159 block6a_se_excite False 160 block6a_project_conv False 161 block6a_project_bn False 162 block6b_expand_conv False 163 block6b_expand_bn False 164 block6b_expand_activation False 165 block6b_dwconv2 False 166 block6b_bn False 167 block6b_activation False 168 block6b_se_squeeze False 169 block6b_se_reshape False 170 block6b_se_reduce False 171 block6b_se_expand False 172 block6b_se_excite False 173 block6b_project_conv False 174 block6b_project_bn False 175 block6b_drop False 176 block6b_add False 177 block6c_expand_conv False 178 block6c_expand_bn False 179 block6c_expand_activation False 180 block6c_dwconv2 False 181 block6c_bn False 182 block6c_activation False 183 block6c_se_squeeze False 184 block6c_se_reshape False 185 block6c_se_reduce False 186 block6c_se_expand False 187 block6c_se_excite False 188 block6c_project_conv False 189 block6c_project_bn False 190 block6c_drop False 191 block6c_add False 192 block6d_expand_conv False 193 block6d_expand_bn False 194 block6d_expand_activation False 195 block6d_dwconv2 False 196 block6d_bn False 197 block6d_activation False 198 block6d_se_squeeze False 199 block6d_se_reshape False 200 block6d_se_reduce False 201 block6d_se_expand False 202 block6d_se_excite False 203 block6d_project_conv False 204 block6d_project_bn False 205 block6d_drop False 206 block6d_add False 207 block6e_expand_conv False 208 block6e_expand_bn False 209 block6e_expand_activation False 210 block6e_dwconv2 False 211 block6e_bn False 212 block6e_activation False 213 block6e_se_squeeze False 214 block6e_se_reshape False 215 block6e_se_reduce False 216 
block6e_se_expand False 217 block6e_se_excite False 218 block6e_project_conv False 219 block6e_project_bn False 220 block6e_drop False 221 block6e_add False 222 block6f_expand_conv False 223 block6f_expand_bn False 224 block6f_expand_activation False 225 block6f_dwconv2 False 226 block6f_bn False 227 block6f_activation False 228 block6f_se_squeeze False 229 block6f_se_reshape False 230 block6f_se_reduce False 231 block6f_se_expand False 232 block6f_se_excite False 233 block6f_project_conv False 234 block6f_project_bn False 235 block6f_drop False 236 block6f_add False 237 block6g_expand_conv False 238 block6g_expand_bn False 239 block6g_expand_activation False 240 block6g_dwconv2 False 241 block6g_bn False 242 block6g_activation False 243 block6g_se_squeeze False 244 block6g_se_reshape False 245 block6g_se_reduce False 246 block6g_se_expand False 247 block6g_se_excite False 248 block6g_project_conv False 249 block6g_project_bn False 250 block6g_drop False 251 block6g_add False 252 block6h_expand_conv False 253 block6h_expand_bn False 254 block6h_expand_activation False 255 block6h_dwconv2 False 256 block6h_bn False 257 block6h_activation False 258 block6h_se_squeeze False 259 block6h_se_reshape False 260 block6h_se_reduce True 261 block6h_se_expand True 262 block6h_se_excite True 263 block6h_project_conv True 264 block6h_project_bn True 265 block6h_drop True 266 block6h_add True 267 top_conv True 268 top_bn True 269 top_activation True
Nice, the last 10 layers are unfrozen!
Question: Why did we recompile the model?
Every time you make a change to a model's architecture or trainable parameters, you need to recompile it.
In our case, the loss, optimizer, and metrics are the same as before, except the learning rate, which is 10x smaller.
We lower the Adam learning rate so the model doesn't overwrite the existing pretrained weights too quickly. Basically, we want learning to be more gradual.
Note: There's no set standard for setting the learning rate during fine-tuning, though reductions of 2.6x-10x+ seem to work well in practice.
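To make the 2.6x figure concrete: ULMFiT uses "discriminative" learning rates, where each earlier layer group gets its rate divided by a fixed factor (2.6 in the paper). Here's a small illustrative sketch of that idea; the numbers are examples, not values from our model:

```python
def discriminative_learning_rates(base_lr, num_groups, factor=2.6):
    """Return per-group learning rates, smallest for the earliest (most general) layers."""
    # Group 0 = earliest layers, last group = closest to the output
    return [base_lr / (factor ** (num_groups - 1 - i)) for i in range(num_groups)]

lrs = discriminative_learning_rates(base_lr=1e-4, num_groups=3)
print(lrs)  # earliest layers get the smallest learning rate
```

Keras doesn't apply per-layer rates out of the box, so in this notebook we use the simpler approach of one globally reduced rate (1e-4 instead of 1e-3).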
How many trainable variables do we have now?
print(len(model_2.trainable_variables))
12
It looks like we have 12 trainable variables: the weight tensors belonging to the 10 unfrozen layers of the base model (layers like the drop and add layers have no weights, while the conv and batch norm layers contribute one or two tensors each), plus the kernel and bias of the Dense output layer.
Time to fine tune :)
We've already done 5 epochs worth of training on model_2, so we'll continue where it left off and train for another 5 epochs.
To do this, we can use the initial_epoch parameter of the fit() method, passing it the last epoch of the previous model's training history (history_10_percent_data_aug.epoch[-1]).
# Fine-tune for another 5 epochs
fine_tune_epochs = initial_epochs + 5
# Refit the model (same as model_2 but with more trainable layers)
history_fine_10_percent_data_aug = model_2.fit(train_data_10_percent,
                                               epochs=fine_tune_epochs,
                                               validation_data=test_data,
                                               initial_epoch=history_10_percent_data_aug.epoch[-1], # start from previous last epoch
                                               validation_steps=int(0.25 * len(test_data)),
                                               callbacks=[create_tensorboard_callback('transfer_learning', '10_percent_fine_tune_last_10')]) # name the experiment accordingly
Saving TensorBoard log files to: transfer_learning/10_percent_fine_tune_last_10/20251006-205334 Epoch 5/10 24/24 ━━━━━━━━━━━━━━━━━━━━ 30s 771ms/step - accuracy: 0.7533 - loss: 1.0031 - val_accuracy: 0.8240 - val_loss: 0.7592 Epoch 6/10 24/24 ━━━━━━━━━━━━━━━━━━━━ 15s 656ms/step - accuracy: 0.7800 - loss: 0.9083 - val_accuracy: 0.8487 - val_loss: 0.6563 Epoch 7/10 24/24 ━━━━━━━━━━━━━━━━━━━━ 16s 661ms/step - accuracy: 0.7787 - loss: 0.8607 - val_accuracy: 0.8520 - val_loss: 0.6279 Epoch 8/10 24/24 ━━━━━━━━━━━━━━━━━━━━ 16s 660ms/step - accuracy: 0.8147 - loss: 0.7854 - val_accuracy: 0.8602 - val_loss: 0.5753 Epoch 9/10 24/24 ━━━━━━━━━━━━━━━━━━━━ 17s 703ms/step - accuracy: 0.8013 - loss: 0.7408 - val_accuracy: 0.8569 - val_loss: 0.5507 Epoch 10/10 24/24 ━━━━━━━━━━━━━━━━━━━━ 16s 674ms/step - accuracy: 0.8133 - loss: 0.7238 - val_accuracy: 0.8536 - val_loss: 0.5182
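A subtlety worth noticing in the output above: with initial_epoch=4 (the value of history_10_percent_data_aug.epoch[-1], since epoch numbers are 0-indexed) and epochs=10, Keras resumes at the epoch labelled 5/10 and runs through 10/10. A quick pure-Python sketch of that numbering:

```python
initial_epoch = 4      # history_10_percent_data_aug.epoch[-1] after 5 epochs (0-indexed)
fine_tune_epochs = 10  # initial_epochs (5) + 5

# Keras iterates epoch indices [initial_epoch, fine_tune_epochs) and labels them index + 1
epoch_indices = list(range(initial_epoch, fine_tune_epochs))
labels = [f'Epoch {i + 1}/{fine_tune_epochs}' for i in epoch_indices]
print(labels[0], '...', labels[-1])  # → Epoch 5/10 ... Epoch 10/10
print(len(epoch_indices))            # → 6
```

So passing epoch[-1] actually logs 6 more epochs (5/10 through 10/10), matching the progress bars above.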
Note: Fine-tuning usually takes a lot longer than feature extraction, as the model has to update more weights throughout the network.
Our model looks to have improved by a few percentage points. Let's evaluate it.
# Evaluate the model on the test data
results_fine_tune_10_percent = model_2.evaluate(test_data)
79/79 ━━━━━━━━━━━━━━━━━━━━ 24s 304ms/step - accuracy: 0.8460 - loss: 0.5296
We need a way to compare model performance before and after fine-tuning. How about we write a function to do the comparison for us?
import matplotlib.pyplot as plt

def compare_history(original_history, new_history, initial_epochs=5):
    '''
    Compare two TensorFlow model History objects.
    '''
    # Get original history measurements
    acc = original_history.history['accuracy']
    loss = original_history.history['loss']
    print(len(acc))
    val_acc = original_history.history['val_accuracy']
    val_loss = original_history.history['val_loss']

    # Combine original history with new history
    total_acc = acc + new_history.history['accuracy']
    total_loss = loss + new_history.history['loss']
    total_val_acc = val_acc + new_history.history['val_accuracy']
    total_val_loss = val_loss + new_history.history['val_loss']
    print(len(total_acc))
    print(total_acc)

    # Make plots
    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(total_acc, label='Training Accuracy')
    plt.plot(total_val_acc, label='Validation Accuracy')
    plt.plot([initial_epochs-1, initial_epochs-1],
             plt.ylim(), label='Start Fine Tuning') # boundary line between feature extraction and fine-tuning
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(total_loss, label='Training Loss')
    plt.plot(total_val_loss, label='Validation Loss')
    plt.plot([initial_epochs-1, initial_epochs-1],
             plt.ylim(), label='Start Fine Tuning') # boundary line between feature extraction and fine-tuning
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')
    plt.show()
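The core of compare_history is just list concatenation over the History.history dicts. A quick framework-free check with the same shapes as our run (5 feature-extraction epochs plus 6 logged fine-tuning epochs; the accuracy values below are made up for illustration):

```python
# Hypothetical stand-ins for History.history dicts (values are illustrative)
original = {'accuracy': [0.31, 0.61, 0.72, 0.73, 0.76]}           # 5 epochs of feature extraction
new = {'accuracy': [0.75, 0.78, 0.78, 0.81, 0.80, 0.81]}          # 6 logged fine-tuning epochs

total_acc = original['accuracy'] + new['accuracy']
print(len(original['accuracy']), len(total_acc))  # → 5 11
```

The 5 and 11 match what the function prints when we call it on our real histories.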
This is where saving the history variables of our model training comes in handy. Let's see what happened after fine-tuning the last 10 layers of our model.
compare_history(original_history=history_10_percent_data_aug,
new_history=history_fine_10_percent_data_aug,
initial_epochs=5)
5 11 [0.31200000643730164, 0.6079999804496765, 0.718666672706604, 0.7306666374206543, 0.762666642665863, 0.753333330154419, 0.7799999713897705, 0.7786666750907898, 0.8146666884422302, 0.8013333082199097, 0.8133333325386047]
The model's curves flattened out quite quickly, but it still made some progress after fine-tuning began.
Model 4: Fine-tuning an existing model on all of the data¶
Let's try fine-tuning again, but this time with the entire dataset of 10 food classes.
# Setup data directories
train_dir = '10_food_classes_all_data/train/'
test_dir = test_dir # the test directory stays the same as before
# How many images are we working with?
walk_through_dir('10_food_classes_all_data'), walk_through_dir('10_food_classes_10_percent/test')
There are 1 directories and 0 images in '10_food_classes_all_data'. There are 10 directories and 0 images in '10_food_classes_all_data\train'. There are 0 directories and 750 images in '10_food_classes_all_data\train\ice_cream'. There are 0 directories and 750 images in '10_food_classes_all_data\train\chicken_curry'. There are 0 directories and 750 images in '10_food_classes_all_data\train\steak'. There are 0 directories and 750 images in '10_food_classes_all_data\train\sushi'. There are 0 directories and 750 images in '10_food_classes_all_data\train\chicken_wings'. There are 0 directories and 750 images in '10_food_classes_all_data\train\grilled_salmon'. There are 0 directories and 750 images in '10_food_classes_all_data\train\hamburger'. There are 0 directories and 750 images in '10_food_classes_all_data\train\pizza'. There are 0 directories and 750 images in '10_food_classes_all_data\train\ramen'. There are 0 directories and 750 images in '10_food_classes_all_data\train\fried_rice'. There are 10 directories and 0 images in '10_food_classes_10_percent/test'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\ice_cream'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\chicken_curry'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\steak'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\sushi'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\chicken_wings'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\grilled_salmon'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\hamburger'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\pizza'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\ramen'. There are 0 directories and 250 images in '10_food_classes_10_percent/test\fried_rice'.
(None, None)
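walk_through_dir was defined in an earlier section of the course; in case it's missing from your runtime, here's a minimal sketch of it (the exact message wording of the original helper may differ):

```python
import os

def walk_through_dir(dir_path):
    """Walk through dir_path, printing how many subdirectories and images each folder holds."""
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
```

Pointing it at a dataset root gives you a quick sanity check of the directory structure before building data loaders.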
# Setup data inputs
import tensorflow as tf
IMG_SIZE = (224, 224)
train_data_10_classes_full = tf.keras.preprocessing.image_dataset_from_directory(train_dir,
                                                                                 label_mode='categorical',
                                                                                 image_size=IMG_SIZE)
# Note: this is the same test dataset we've been using for the previous modelling experiments
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
                                                                label_mode='categorical',
                                                                image_size=IMG_SIZE)
Found 7500 files belonging to 10 classes. Found 2500 files belonging to 10 classes.
We got the same test set, but now 10x the training data.
Our model_2 has been fine-tuned on 10 percent of the data, so to begin fine-tuning on all the data and keep our experiments consistent, we need to revert it back to the weights we checkpointed at epoch 5, during feature extraction.
To demonstrate this, we'll first evaluate the current model_2.
# evaluate model (this is the fine tuned 10 percent)
model_2.evaluate(test_data)
79/79 ━━━━━━━━━━━━━━━━━━━━ 24s 307ms/step - accuracy: 0.8460 - loss: 0.5296
[0.529642641544342, 0.8460000157356262]
These are the same values as results_fine_tune_10_percent (the variable we saved when evaluating the model earlier).
results_fine_tune_10_percent
[0.5296427607536316, 0.8460000157356262]
To keep our experiments clean, we'll create a new instance of model_2 using the create_base_model() function.
More specifically, we're trying to measure:
- Experiment 3 (the previous one) - model_2 with the top 10 layers fine-tuned for 5 more epochs on 10% of the data.
- Experiment 4 (this one) - model_2 with the top 10 layers fine-tuned for 5 more epochs on 100% of the data.
Both experiments use the same test data to allow comparison of results.
Plus, they should both start from the same checkpoint (model_2's feature extractor trained for 5 epochs on 10% of the data).
Let's first create a new instance of model_2.
# create a new instance of model_2 for Experiment 4
model_2 = create_base_model(learning_rate=0.0001) # x10 lower learning rate for fine-tuning, like experiment 3
And now to make sure it starts at the same checkpoint, we can load the checkpointed weights from checkpoint_path.
# load previously checkpointed weights
model_2.load_weights(checkpoint_path)
x:\miniconda3\envs\tfenv\lib\site-packages\keras\src\saving\saving_lib.py:797: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 6 variables. saveable.load_own_variables(weights_store.get(inner_path))
Let's now get a summary and check how many trainable variables there are.
model_2.summary()
Model: "functional_7"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 224, 224, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ data_augmentation (Sequential) │ (None, None, None, 3) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ efficientnetv2-b0 (Functional) │ (None, None, None, │ 5,919,312 │ │ │ 1280) │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_average_pooling_layer │ (None, 1280) │ 0 │ │ (GlobalAveragePooling2D) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ output_layer (Dense) │ (None, 10) │ 12,810 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 5,932,122 (22.63 MB)
Trainable params: 12,810 (50.04 KB)
Non-trainable params: 5,919,312 (22.58 MB)
print(len(model_2.trainable_variables))
2
This should now be the same as the original model_2 before fine tuning.
Results should be the same as results_10_percent_data_aug.
# After loading in weights, the values should have gone back down (due to no fine-tuning).
model_2.evaluate(test_data)
79/79 ━━━━━━━━━━━━━━━━━━━━ 29s 311ms/step - accuracy: 0.8164 - loss: 0.7025
[0.7025160193443298, 0.8163999915122986]
So the steps we've taken to reach here are:
- Trained a feature extraction transfer learning model for 5 epochs on 10% of the data (base model entirely frozen) and saved a checkpoint of the model's weights.
- Fine-tuned the same model on the same 10% of the data for another 5 epochs, with the top 10 layers of the base model unfrozen.
- Saved the results and training logs for each experiment.
- Reloaded model_2 from the epoch 5 checkpoint, ready to repeat step 2 with 100% of the data instead of 10%.
So we'll fine-tune the top 10 layers again for another 5 epochs on 100% of the data. But first, let's remind ourselves which layers are trainable.
# Check which layers are tuneable in the whole model
for layer_number, layer in enumerate(model_2.layers):
    print(layer_number, layer.name, layer.trainable)
0 input_layer True 1 data_augmentation True 2 efficientnetv2-b0 False 3 global_average_pooling_layer True 4 output_layer True
Remember that our base model can be referenced by its index within model_2, e.g. model_2.layers[2].
So let's unfreeze the last 10 layers of the base model to make them trainable (i.e. fine-tune them).
# Unfreeze the top 10 layers in model_2's base model
model_2_base_model = model_2.layers[2]
model_2_base_model.trainable = True
# Freeze all layers except for the last 10
for layer in model_2_base_model.layers[:-10]:
    layer.trainable = False
Now let's make sure the right layers are trainable.
Note: You could experiment which number of layers should be trainable here. Generally, the more data you have, the more layers you can fine-tune to account for it.
# Check which layers are tuneable in the base model
for layer_number, layer in enumerate(model_2_base_model.layers):
    print(layer_number, layer.name, layer.trainable)
0 input_layer_10 False 1 rescaling_7 False 2 normalization_7 False 3 stem_conv False 4 stem_bn False 5 stem_activation False 6 block1a_project_conv False 7 block1a_project_bn False 8 block1a_project_activation False 9 block2a_expand_conv False 10 block2a_expand_bn False 11 block2a_expand_activation False 12 block2a_project_conv False 13 block2a_project_bn False 14 block2b_expand_conv False 15 block2b_expand_bn False 16 block2b_expand_activation False 17 block2b_project_conv False 18 block2b_project_bn False 19 block2b_drop False 20 block2b_add False 21 block3a_expand_conv False 22 block3a_expand_bn False 23 block3a_expand_activation False 24 block3a_project_conv False 25 block3a_project_bn False 26 block3b_expand_conv False 27 block3b_expand_bn False 28 block3b_expand_activation False 29 block3b_project_conv False 30 block3b_project_bn False 31 block3b_drop False 32 block3b_add False 33 block4a_expand_conv False 34 block4a_expand_bn False 35 block4a_expand_activation False 36 block4a_dwconv2 False 37 block4a_bn False 38 block4a_activation False 39 block4a_se_squeeze False 40 block4a_se_reshape False 41 block4a_se_reduce False 42 block4a_se_expand False 43 block4a_se_excite False 44 block4a_project_conv False 45 block4a_project_bn False 46 block4b_expand_conv False 47 block4b_expand_bn False 48 block4b_expand_activation False 49 block4b_dwconv2 False 50 block4b_bn False 51 block4b_activation False 52 block4b_se_squeeze False 53 block4b_se_reshape False 54 block4b_se_reduce False 55 block4b_se_expand False 56 block4b_se_excite False 57 block4b_project_conv False 58 block4b_project_bn False 59 block4b_drop False 60 block4b_add False 61 block4c_expand_conv False 62 block4c_expand_bn False 63 block4c_expand_activation False 64 block4c_dwconv2 False 65 block4c_bn False 66 block4c_activation False 67 block4c_se_squeeze False 68 block4c_se_reshape False 69 block4c_se_reduce False 70 block4c_se_expand False 71 block4c_se_excite False 72 block4c_project_conv False 73 
block4c_project_bn False 74 block4c_drop False 75 block4c_add False 76 block5a_expand_conv False 77 block5a_expand_bn False 78 block5a_expand_activation False 79 block5a_dwconv2 False 80 block5a_bn False 81 block5a_activation False 82 block5a_se_squeeze False 83 block5a_se_reshape False 84 block5a_se_reduce False 85 block5a_se_expand False 86 block5a_se_excite False 87 block5a_project_conv False 88 block5a_project_bn False 89 block5b_expand_conv False 90 block5b_expand_bn False 91 block5b_expand_activation False 92 block5b_dwconv2 False 93 block5b_bn False 94 block5b_activation False 95 block5b_se_squeeze False 96 block5b_se_reshape False 97 block5b_se_reduce False 98 block5b_se_expand False 99 block5b_se_excite False 100 block5b_project_conv False 101 block5b_project_bn False 102 block5b_drop False 103 block5b_add False 104 block5c_expand_conv False 105 block5c_expand_bn False 106 block5c_expand_activation False 107 block5c_dwconv2 False 108 block5c_bn False 109 block5c_activation False 110 block5c_se_squeeze False 111 block5c_se_reshape False 112 block5c_se_reduce False 113 block5c_se_expand False 114 block5c_se_excite False 115 block5c_project_conv False 116 block5c_project_bn False 117 block5c_drop False 118 block5c_add False 119 block5d_expand_conv False 120 block5d_expand_bn False 121 block5d_expand_activation False 122 block5d_dwconv2 False 123 block5d_bn False 124 block5d_activation False 125 block5d_se_squeeze False 126 block5d_se_reshape False 127 block5d_se_reduce False 128 block5d_se_expand False 129 block5d_se_excite False 130 block5d_project_conv False 131 block5d_project_bn False 132 block5d_drop False 133 block5d_add False 134 block5e_expand_conv False 135 block5e_expand_bn False 136 block5e_expand_activation False 137 block5e_dwconv2 False 138 block5e_bn False 139 block5e_activation False 140 block5e_se_squeeze False 141 block5e_se_reshape False 142 block5e_se_reduce False 143 block5e_se_expand False 144 block5e_se_excite False 145 
block5e_project_conv False 146 block5e_project_bn False 147 block5e_drop False 148 block5e_add False 149 block6a_expand_conv False 150 block6a_expand_bn False 151 block6a_expand_activation False 152 block6a_dwconv2 False 153 block6a_bn False 154 block6a_activation False 155 block6a_se_squeeze False 156 block6a_se_reshape False 157 block6a_se_reduce False 158 block6a_se_expand False 159 block6a_se_excite False 160 block6a_project_conv False 161 block6a_project_bn False 162 block6b_expand_conv False 163 block6b_expand_bn False 164 block6b_expand_activation False 165 block6b_dwconv2 False 166 block6b_bn False 167 block6b_activation False 168 block6b_se_squeeze False 169 block6b_se_reshape False 170 block6b_se_reduce False 171 block6b_se_expand False 172 block6b_se_excite False 173 block6b_project_conv False 174 block6b_project_bn False 175 block6b_drop False 176 block6b_add False 177 block6c_expand_conv False 178 block6c_expand_bn False 179 block6c_expand_activation False 180 block6c_dwconv2 False 181 block6c_bn False 182 block6c_activation False 183 block6c_se_squeeze False 184 block6c_se_reshape False 185 block6c_se_reduce False 186 block6c_se_expand False 187 block6c_se_excite False 188 block6c_project_conv False 189 block6c_project_bn False 190 block6c_drop False 191 block6c_add False 192 block6d_expand_conv False 193 block6d_expand_bn False 194 block6d_expand_activation False 195 block6d_dwconv2 False 196 block6d_bn False 197 block6d_activation False 198 block6d_se_squeeze False 199 block6d_se_reshape False 200 block6d_se_reduce False 201 block6d_se_expand False 202 block6d_se_excite False 203 block6d_project_conv False 204 block6d_project_bn False 205 block6d_drop False 206 block6d_add False 207 block6e_expand_conv False 208 block6e_expand_bn False 209 block6e_expand_activation False 210 block6e_dwconv2 False 211 block6e_bn False 212 block6e_activation False 213 block6e_se_squeeze False 214 block6e_se_reshape False 215 block6e_se_reduce False 216 
block6e_se_expand False 217 block6e_se_excite False 218 block6e_project_conv False 219 block6e_project_bn False 220 block6e_drop False 221 block6e_add False 222 block6f_expand_conv False 223 block6f_expand_bn False 224 block6f_expand_activation False 225 block6f_dwconv2 False 226 block6f_bn False 227 block6f_activation False 228 block6f_se_squeeze False 229 block6f_se_reshape False 230 block6f_se_reduce False 231 block6f_se_expand False 232 block6f_se_excite False 233 block6f_project_conv False 234 block6f_project_bn False 235 block6f_drop False 236 block6f_add False 237 block6g_expand_conv False 238 block6g_expand_bn False 239 block6g_expand_activation False 240 block6g_dwconv2 False 241 block6g_bn False 242 block6g_activation False 243 block6g_se_squeeze False 244 block6g_se_reshape False 245 block6g_se_reduce False 246 block6g_se_expand False 247 block6g_se_excite False 248 block6g_project_conv False 249 block6g_project_bn False 250 block6g_drop False 251 block6g_add False 252 block6h_expand_conv False 253 block6h_expand_bn False 254 block6h_expand_activation False 255 block6h_dwconv2 False 256 block6h_bn False 257 block6h_activation False 258 block6h_se_squeeze False 259 block6h_se_reshape False 260 block6h_se_reduce True 261 block6h_se_expand True 262 block6h_se_excite True 263 block6h_project_conv True 264 block6h_project_bn True 265 block6h_drop True 266 block6h_add True 267 top_conv True 268 top_bn True 269 top_activation True
Looks like our last 10 layers are trainable.
Now let's recompile the model so our changes take effect.
# Recompile the model (always recompile after any adjustments to a model)
model_2.compile(loss='categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), # learning rate 10x lower than the default for fine-tuning
                metrics=['accuracy'])
Alright, time to fine-tune the model on the full data :)
# Continue to train and fine-tune the model on our data
fine_tune_epochs = initial_epochs + 5

history_fine_10_classes_full = model_2.fit(train_data_10_classes_full,
                                           epochs=fine_tune_epochs,
                                           initial_epoch=history_10_percent_data_aug.epoch[-1], # resume from the last feature-extraction epoch
                                           validation_data=test_data,
                                           validation_steps=int(0.25 * len(test_data)), # validate on 25% of the test data to save time
                                           callbacks=[create_tensorboard_callback('transfer_learning', 'full_10_classes_fine_tune_last_10')])
Saving TensorBoard log files to: transfer_learning/full_10_classes_fine_tune_last_10/20251006-220237
Epoch 5/10
235/235 ━━━━━━━━━━━━━━━━━━━━ 139s 543ms/step - accuracy: 0.7387 - loss: 0.9286 - val_accuracy: 0.8520 - val_loss: 0.4477
Epoch 6/10
235/235 ━━━━━━━━━━━━━━━━━━━━ 143s 607ms/step - accuracy: 0.7877 - loss: 0.7272 - val_accuracy: 0.8832 - val_loss: 0.3469
Epoch 7/10
235/235 ━━━━━━━━━━━━━━━━━━━━ 132s 564ms/step - accuracy: 0.8053 - loss: 0.6441 - val_accuracy: 0.8947 - val_loss: 0.3118
Epoch 8/10
235/235 ━━━━━━━━━━━━━━━━━━━━ 131s 560ms/step - accuracy: 0.8119 - loss: 0.6017 - val_accuracy: 0.9145 - val_loss: 0.2662
Epoch 9/10
235/235 ━━━━━━━━━━━━━━━━━━━━ 133s 568ms/step - accuracy: 0.8260 - loss: 0.5621 - val_accuracy: 0.9112 - val_loss: 0.2741
Epoch 10/10
235/235 ━━━━━━━━━━━━━━━━━━━━ 132s 563ms/step - accuracy: 0.8320 - loss: 0.5356 - val_accuracy: 0.9112 - val_loss: 0.2673
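One part of the `fit()` call above that can trip people up is the epochs arithmetic: since `initial_epoch` picks up from the last logged epoch of the feature-extraction run, `fit()` only trains the remaining epochs. A quick pure-Python sanity check (no TensorFlow needed, values taken from this notebook):

```python
# Sketch of how Keras interprets `initial_epoch` and `epochs`.
# Assumption: the feature-extraction phase ran for initial_epochs = 5,
# so its History logged epochs 0..4 and history.epoch[-1] == 4.
initial_epochs = 5
fine_tune_epochs = initial_epochs + 5     # 10 total epochs across both phases
last_logged_epoch = initial_epochs - 1    # history_10_percent_data_aug.epoch[-1]

# fit() resumes at epoch last_logged_epoch and counts up to fine_tune_epochs,
# which is why the log above shows "Epoch 5/10" through "Epoch 10/10"
epochs_run_during_fine_tuning = fine_tune_epochs - last_logged_epoch
print(epochs_run_during_fine_tuning)  # → 6
```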
Note: Training took much longer this time because we're now fine-tuning on 10x more data than before.
Let's evaluate the fine-tuned model on the full test dataset.
# Evaluate the fine-tuned model on the whole test dataset
results_fine_tune_full_data = model_2.evaluate(test_data)
results_fine_tune_full_data
79/79 ━━━━━━━━━━━━━━━━━━━━ 31s 391ms/step - accuracy: 0.9156 - loss: 0.2604
[0.26041892170906067, 0.9156000018119812]
# Compare with the earlier results from fine-tuning on only 10% of the data
results_fine_tune_10_percent
[0.5296427607536316, 0.8460000157356262]
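To put a number on the improvement, we can compare the two `evaluate()` results directly. Each result is a `[loss, accuracy]` pair; the values below are copied (rounded) from the outputs above:

```python
# [loss, accuracy] pairs copied from the evaluate() outputs above
results_10_percent = [0.5296, 0.8460]  # fine-tuned on 10% of the data
results_full_data = [0.2604, 0.9156]   # fine-tuned on 100% of the data

acc_boost = results_full_data[1] - results_10_percent[1]
loss_drop = results_10_percent[0] - results_full_data[0]
print(f"Accuracy boost: {acc_boost:.1%}, loss reduction: {loss_drop:.4f}")
# → Accuracy boost: 7.0%, loss reduction: 0.2692
```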
It looks like fine-tuning on the full training data has given the model's test accuracy a nice boost (~84.6% with 10% of the data vs. ~91.6% with all of it)!
# How did fine-tuning go with more data?
compare_history(original_history=history_10_percent_data_aug,
                new_history=history_fine_10_classes_full,
                initial_epochs=5)
[compare_history plot: combined training curves across 11 epochs (5 feature extraction + 6 fine-tuning), with training accuracy climbing from ~0.312 to ~0.832 and a vertical line marking the start of fine-tuning at epoch 5]
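Under the hood, `compare_history` stitches the metric lists from the two `History` objects together before plotting. A minimal sketch of that combining step (hypothetical helper; plain dicts stand in for the Keras `History.history` attribute, accuracy values rounded from this run):

```python
def combine_metric(original_history, new_history, key="accuracy"):
    # Concatenate a metric from the feature-extraction run with the fine-tuning run
    return original_history[key] + new_history[key]

original = {"accuracy": [0.31, 0.61, 0.72, 0.73, 0.76]}   # 5 feature-extraction epochs
new = {"accuracy": [0.74, 0.79, 0.81, 0.81, 0.83, 0.83]}  # 6 fine-tuning epochs

total_acc = combine_metric(original, new)
print(len(total_acc))  # → 11 points plotted: 5 before fine-tuning + 6 after
```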
Seems like that extra data has helped edge performance up a bit! The curves suggest the model would most likely keep improving if trained for longer.
🛠 Exercises¶
- Write a function to visualize an image from any dataset (train or test) and any class (e.g. "steak", "pizza", etc.), then make a prediction on it using a trained model.
- Use feature extraction to train a transfer learning model on 10% of the Food Vision data for 10 epochs using tf.keras.applications.efficientnet_v2.EfficientNetV2B0 as the base model. Use the ModelCheckpoint callback to save the weights to file.
- Fine-tune the last 20 layers of the base model you trained in exercise 2 for another 10 epochs. How did it go?
- Fine-tune the last 30 layers of the base model you trained in exercise 2 for another 10 epochs. How did it go?
📖 Extra-curriculum¶
- Read the documentation on data augmentation in TensorFlow.
- Read the ULMFit paper (technical) for an introduction to the concept of freezing and unfreezing different layers.
- Read up on learning rate scheduling (there's a TensorFlow callback for this), how could this influence our model training?
- If you're training for longer, you probably want to reduce the learning rate as you go... the closer you get to the bottom of the hill, the smaller the steps you want to take. Imagine it like finding a coin at the bottom of your couch: in the beginning your arm movements are large, and the closer you get, the smaller your movements become.
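That intuition maps directly onto learning rate decay. Here's a toy sketch of an exponential schedule (pure Python; in Keras you'd typically hand a schedule function like this to tf.keras.callbacks.LearningRateScheduler):

```python
def exponential_decay(epoch, initial_lr=1e-4, decay_rate=0.9):
    # Shrink the learning rate by decay_rate every epoch (smaller steps later on)
    return initial_lr * (decay_rate ** epoch)

lrs = [exponential_decay(epoch) for epoch in range(10)]
print(f"epoch 0: {lrs[0]:.2e} -> epoch 9: {lrs[-1]:.2e}")
# → epoch 0: 1.00e-04 -> epoch 9: 3.87e-05
```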